Thursday, August 18, 2022

Python Classes and Their Use in Keras



Last Updated on December 24, 2021

Classes are one of the fundamental building blocks of the Python language, which may be applied in the development of machine learning applications. As we shall see, the Python syntax for developing classes is simple and can be applied to implement callbacks in Keras. 

In this tutorial, you will discover Python classes and their functionality. 

After completing this tutorial, you will know:

Why Python classes are important.
How to define and instantiate a class, and set its attributes. 
How to create methods and pass arguments. 
What class inheritance is. 
How to use classes to implement callbacks in Keras. 

Let’s get started. 

Python Classes and Their Use in Keras
Photo by S Migaj, some rights reserved.

Tutorial Overview

This tutorial is divided into six parts; they are:

Introduction to Classes
Defining a Class
Instantiation and Attribute References
Creating Methods and Passing Arguments
Class Inheritance
Using Classes in Keras

Introduction to Classes

In object-oriented languages, such as Python, classes are one of the fundamental building blocks. 

They can be likened to blueprints for an object, as they define what properties and methods/behaviors an object should have.

Python Fundamentals, 2018.

Creating a new class creates a new type of object, from which new instances can be made. Every class instance can be characterized by its attributes, which maintain its state, and by its methods, which modify its state.

Defining a Class

The class keyword allows for the creation of a new class definition, immediately followed by the class name:

class MyClass:
    <statements>

In this manner, a new class object bound to the specified class name (MyClass, in this particular case) is created. Each class object can support instantiation and attribute references, as we will see shortly.

Instantiation and Attribute References

Instantiation is the creation of a new instance of a class.

To create a new instance of a class, we call the class by its name and assign the result to a variable. This creates a new, empty instance of the class:

x = MyClass()
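The following minimal sketch makes this concrete. The `pass` body and the hypothetical `note` attribute are our own illustration. It shows that each call to the class produces a distinct instance, and that an attribute attached to one instance does not appear on another.

```python
class MyClass:
    pass  # an empty class body, used only for illustration

x = MyClass()  # calling the class object creates a new instance
y = MyClass()  # each call creates a separate instance

print(type(x).__name__)        # MyClass
print(isinstance(x, MyClass))  # True
print(x is y)                  # False: two distinct objects

# "note" is a hypothetical attribute attached to x only
x.note = "set after creation"
print(hasattr(y, "note"))      # False: y is unaffected
```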

Upon creating a new instance of a class, Python calls its constructor method, __init__(), which often takes arguments that are used to set the instantiated object’s attributes. 

We can define this constructor method in our class just like a function and specify attributes that will need to be passed in when instantiating an object.

Python Fundamentals, 2018.

Let’s say, for instance, that we would like to define a new class named Dog:

class Dog:
    family = "Canine"

    def __init__(self, name, breed):
        self.name = name
        self.breed = breed

Here, the constructor method takes two arguments, name and breed, which can be passed to it upon instantiating the object:

dog1 = Dog("Lassie", "Rough Collie")

In this example, name and breed are known as instance variables (or attributes), because they are bound to a specific instance. Such attributes belong only to the object in which they have been set, and not to any other object instantiated from the same class. 

On the other hand, family is a class variable (or attribute), because it is shared by all instances of the same class.
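The difference becomes visible once there are two instances. In this sketch, the second dog and the reassigned family values are made up for illustration. Reassigning the class variable on the class changes what every instance sees. Assigning it through one instance instead creates a separate instance variable, which hides the class variable for that instance only.

```python
class Dog:
    family = "Canine"  # class variable: shared by all instances

    def __init__(self, name, breed):
        self.name = name    # instance variables: one copy per object
        self.breed = breed

dog1 = Dog("Lassie", "Rough Collie")
dog2 = Dog("Hachiko", "Akita")     # a second, made-up instance

print(dog1.family, dog2.family)    # Canine Canine

# Rebinding the class variable on the class is seen by every instance
Dog.family = "Canidae"
print(dog1.family, dog2.family)    # Canidae Canidae

# Assigning through an instance creates an instance variable that shadows it
dog1.family = "Wolf-like"
print(dog1.family, dog2.family)    # Wolf-like Canidae
```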

You may also note that the first argument of the constructor method (or of any other instance method) is conventionally called self. This argument refers to the instance itself. In the constructor, it is the object being initialized; in other methods, it is the object on which the method was called. The name self is a convention rather than a keyword, but it is good practice to follow it so that your code stays readable for other programmers. 
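One way to see what self is: calling a method through an instance is equivalent to calling the function on the class and passing the instance explicitly as the first argument. In this sketch, the describe() method and the Cat class are hypothetical. Cat shows that another parameter name also works, although it breaks the convention.

```python
class Dog:
    def __init__(self, name, breed):
        self.name = name
        self.breed = breed

    def describe(self):  # hypothetical method for illustration
        # self is whichever instance the method was called on
        return "{} is a {}".format(self.name, self.breed)

dog1 = Dog("Lassie", "Rough Collie")

# Python passes dog1 as the first argument, so these calls are equivalent
print(dog1.describe())     # Lassie is a Rough Collie
print(Dog.describe(dog1))  # Lassie is a Rough Collie

class Cat:  # hypothetical class
    def __init__(this, name):  # legal, but breaks the self convention
        this.name = name

print(Cat("Tom").name)     # Tom
```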

Once we have set our object’s attributes, they can be accessed using the dot operator. For example, considering again the dog1 instance of the Dog class, its name attribute may be accessed as follows:

print(dog1.name)

Producing the following output:

Lassie

Creating Methods and Passing Arguments

In addition to a constructor method, a class can also define several other methods for modifying the state of its instances. 

The syntax for defining an instance method is familiar. We pass the argument self … It is always the first argument of an instance method.

Python Fundamentals, 2018.

Similar to the constructor method, each instance method can take several arguments, with the first one being the argument self that lets us set and access the object’s attributes:

class Dog:
    family = "Canine"

    def __init__(self, name, breed):
        self.name = name
        self.breed = breed

    def info(self):
        print(self.name, "is a female", self.breed)

Different methods of the same object can also use the self argument to call each other:

class Dog:
    family = "Canine"

    def __init__(self, name, breed):
        self.name = name
        self.breed = breed
        self.tricks = []

    def add_tricks(self, x):
        self.tricks.append(x)

    def info(self, x):
        self.add_tricks(x)
        print(self.name, "is a female", self.breed, "that", self.tricks[0])

An output string can then be generated as follows:

dog1 = Dog("Lassie", "Rough Collie")
dog1.info("barks on command")

In doing so, the string "barks on command" is appended to the tricks list when the info() method calls the add_tricks() method. The following output is produced:

Lassie is a female Rough Collie that barks on command
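The tricks list is created inside __init__(), so every dog gets its own list. The following sketch uses a simplified Dog without a breed, plus a hypothetical counter-example class, SharedTricksDog. It shows what would happen if the list were a class variable instead: every instance would append to one shared list.

```python
class Dog:
    def __init__(self, name):
        self.name = name
        self.tricks = []          # a new list for each instance

    def add_tricks(self, x):
        self.tricks.append(x)

class SharedTricksDog:            # hypothetical counter-example
    tricks = []                   # pitfall: one list shared by every instance

    def __init__(self, name):
        self.name = name

    def add_tricks(self, x):
        self.tricks.append(x)     # mutates the shared class-level list

a, b = Dog("Lassie"), Dog("Rex")
a.add_tricks("barks on command")
print(a.tricks, b.tricks)         # ['barks on command'] []

c, d = SharedTricksDog("Lassie"), SharedTricksDog("Rex")
c.add_tricks("barks on command")
print(c.tricks, d.tricks)         # ['barks on command'] ['barks on command']
```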

Class Inheritance

Another feature that Python supports is class inheritance. 

Inheritance is a mechanism that allows a subclass (also known as a derived or child class) to access all attributes and methods of a superclass (also known as a base or parent class). 

The syntax for defining a subclass is the following:

class SubClass(BaseClass):
    <statements>
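For instance, a subclass of the Dog class from earlier that adds nothing of its own still has everything Dog defines. GuideDog is a hypothetical name used only for illustration.

```python
class Dog:
    family = "Canine"

    def __init__(self, name, breed):
        self.name = name
        self.breed = breed

    def info(self):
        print(self.name, "is a female", self.breed)

class GuideDog(Dog):  # hypothetical subclass of Dog
    pass              # adds nothing: everything is inherited

guide = GuideDog("Lassie", "Rough Collie")
guide.info()                      # Lassie is a female Rough Collie
print(guide.family)               # Canine
print(isinstance(guide, Dog))     # True
print(issubclass(GuideDog, Dog))  # True
```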

A subclass can also inherit from multiple base classes. In this case, the syntax is as follows:

class SubClass(BaseClass1, BaseClass2, BaseClass3):
    <statements>

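When an attribute or method is referenced, Python first searches the subclass itself and then its base class. In the case of multiple inheritance, the base classes are searched in the order in which they are listed (more precisely, following the class’s method resolution order). 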
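The following sketch illustrates the search order with hypothetical classes. When both base classes define move(), the one listed first wins. A method that only the second base class defines, such as fetch(), is still found. The __mro__ attribute shows the order Python uses.

```python
class Swimmer:  # hypothetical base class
    def move(self):
        return "swims"

class Runner:   # hypothetical base class
    def move(self):
        return "runs"

    def fetch(self):
        return "fetches the ball"

class Retriever(Swimmer, Runner):  # hypothetical subclass
    pass

r = Retriever()
print(r.move())   # swims: Swimmer is listed first, so it is searched first
print(r.fetch())  # fetches the ball: not in Swimmer, found in Runner
print([cls.__name__ for cls in Retriever.__mro__])
# ['Retriever', 'Swimmer', 'Runner', 'object']
```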

Python further allows a method in a subclass to override a method of the same name in the base class. The overriding method may replace the base class method entirely, or extend its capabilities, typically by calling the base class version through super(). When an overriding subclass method is available, it is this method that is executed when called, rather than the method with the same name in the base class. 
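The two kinds of overriding can be sketched with a simplified Dog whose info() returns a string instead of printing. ServiceDog and its task attribute are hypothetical. Its __init__() and info() extend the base versions through super(), while its speak() replaces the base version outright. The same super().__init__() call appears in the Keras callback later in this tutorial.

```python
class Dog:
    def __init__(self, name, breed):
        self.name = name
        self.breed = breed

    def info(self):
        return "{} is a {}".format(self.name, self.breed)

    def speak(self):
        return "Woof"

class ServiceDog(Dog):  # hypothetical subclass for illustration
    def __init__(self, name, breed, task):
        super().__init__(name, breed)  # extend: reuse the base initializer
        self.task = task

    def info(self):
        # extend: build on the result of the base class method
        return super().info() + " trained to " + self.task

    def speak(self):
        # replace: the base class version is never called
        return "..."

s = ServiceDog("Lassie", "Rough Collie", "guide")
print(s.info())   # Lassie is a Rough Collie trained to guide
print(s.speak())  # ...
```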

Using Classes in Keras

A practical use of classes in Keras is to write one’s own callbacks. 

A callback is a powerful tool in Keras that allows us to have a look at the behaviour of our model during the different stages of training, testing and prediction. 

Indeed, we may pass a list of callbacks to any of the following:

keras.Model.fit()
keras.Model.evaluate()
keras.Model.predict()

The Keras API comes with several built-in callbacks. Nonetheless, we might wish to write our own and, for this purpose, we shall see how to build a custom callback class. To do so, we subclass the base class, keras.callbacks.Callback, and override the methods that Keras calls to inform us of when:

Training, testing or prediction starts and ends. 
An epoch starts and ends. 
A training, testing or prediction batch starts and ends. 
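It can help to look at a toy version of this mechanism to see why overriding these methods is enough. The following is a hypothetical, pure-Python stand-in, not Keras code. ToyCallback, toy_fit and BatchCounter are our own names, and only the hook names mirror Keras; the real Keras loop does much more. The base class provides do-nothing hook methods, the loop calls them at fixed points, and a subclass overrides only the hooks it cares about.

```python
class ToyCallback:
    """Hypothetical base class: every hook is a no-op by default."""

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        pass


def toy_fit(epochs, batches_per_epoch, callbacks):
    """Hypothetical loop: no real training, it only calls the hooks in order."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            # A real loop would update the model weights here
            for cb in callbacks:
                cb.on_train_batch_end(batch, logs={})
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={})
    for cb in callbacks:
        cb.on_train_end()


class BatchCounter(ToyCallback):
    """Hypothetical subclass: overrides only the hooks it needs."""

    def on_train_begin(self, logs=None):
        self.batches_seen = 0

    def on_train_batch_end(self, batch, logs=None):
        self.batches_seen += 1


counter = BatchCounter()
toy_fit(epochs=3, batches_per_epoch=4, callbacks=[counter])
print(counter.batches_seen)  # 12
```

The hooks that BatchCounter does not override are inherited from the base class and do nothing, so the loop can call every hook unconditionally.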

Let’s first consider a simple example of a custom callback that reports every time an epoch starts and ends. We will name this custom callback class EpochCallback and override the epoch-level methods, on_epoch_begin() and on_epoch_end(), from the base class, keras.callbacks.Callback:

from tensorflow import keras

class EpochCallback(keras.callbacks.Callback):
    # Called by Keras at the start of every epoch (epoch is 0-based)
    def on_epoch_begin(self, epoch, logs=None):
        print("Starting epoch {}".format(epoch + 1))

    # Called by Keras at the end of every epoch
    def on_epoch_end(self, epoch, logs=None):
        print("Finished epoch {}".format(epoch + 1))

In order to test the custom callback that we have just defined, we need a model to train. For this purpose, let’s define a simple Keras model:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

def simple_model():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28)))
    model.add(Dense(128, activation="relu"))
    model.add(Dense(10, activation="softmax"))

    model.compile(loss="categorical_crossentropy",
                  optimizer="sgd",
                  metrics=["accuracy"])
    return model

We also need a dataset to train on, for which purpose we will be using the MNIST dataset:

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

# Loading the MNIST training and testing data splits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Pre-processing the training data: scale pixel values to [0, 1]
x_train = x_train / 255.0
# x_train already has shape (60000, 28, 28), matching the Flatten layer's input_shape
y_train_cat = to_categorical(y_train, 10)

Now, let’s try out the custom callback by adding it to the list of callbacks that we pass as input to the keras.Model.fit() method:

model = simple_model()

model.fit(x_train,
          y_train_cat,
          batch_size=32,
          epochs=5,
          callbacks=[EpochCallback()],
          verbose=0)

The callback that we have just created produces the following output:

Starting epoch 1
Finished epoch 1
Starting epoch 2
Finished epoch 2
Starting epoch 3
Finished epoch 3
Starting epoch 4
Finished epoch 4
Starting epoch 5
Finished epoch 5

We can create another custom callback that monitors the loss value at the end of each epoch and stores the model weights only if the loss has decreased. To this end, we will read the loss value from the logs dict, which Keras passes to the callback methods and which holds the metrics computed at the end of each batch and epoch. We will also access the model being trained, evaluated or used for prediction through self.model. 

Let’s call this custom callback CheckpointCallback:

import numpy as np

class CheckpointCallback(keras.callbacks.Callback):

    def __init__(self):
        super().__init__()
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # Reset the best loss at the start of every training run
        self.best_loss = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current_loss = logs.get("loss")
        print("Current loss is {}".format(current_loss))
        # Keep a copy of the weights only if the loss has improved
        if np.less(current_loss, self.best_loss):
            self.best_loss = current_loss
            self.best_weights = self.model.get_weights()
            print("Storing the model weights at epoch {}\n".format(epoch + 1))
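The comparison logic is plain Python, so it can be checked without training anything. The following is a hypothetical sketch. StubModel and BestWeightsTracker are our own stand-ins: StubModel only mimics get_weights() and set_weights() with a list, and the epoch losses are made-up numbers fed in by hand. The sketch also adds one step that the tutorial's callback does not take: restoring the best weights when training ends. In Keras, a real callback could do this in on_train_end() by calling self.model.set_weights(self.best_weights).

```python
import math

class StubModel:
    """Hypothetical stand-in for a Keras model: weights are a plain list."""

    def __init__(self):
        self.weights = [0.0]

    def get_weights(self):
        return list(self.weights)          # return a copy, like a snapshot

    def set_weights(self, weights):
        self.weights = list(weights)


class BestWeightsTracker:
    """Hypothetical class mirroring the CheckpointCallback logic."""

    def __init__(self, model):
        self.model = model
        self.best_weights = None
        self.saved_at = []                 # epochs (1-based) where weights were stored

    def on_train_begin(self, logs=None):
        self.best_loss = math.inf

    def on_epoch_end(self, epoch, logs=None):
        current_loss = logs.get("loss")
        if current_loss < self.best_loss:  # store only on improvement
            self.best_loss = current_loss
            self.best_weights = self.model.get_weights()
            self.saved_at.append(epoch + 1)

    def on_train_end(self, logs=None):
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)  # roll back to the best epoch


model = StubModel()
tracker = BestWeightsTracker(model)
tracker.on_train_begin()
for epoch, loss in enumerate([0.63, 0.34, 0.26, 0.41]):  # made-up losses
    model.weights = [float(epoch)]         # pretend training changed the weights
    tracker.on_epoch_end(epoch, logs={"loss": loss})
tracker.on_train_end()

print(tracker.saved_at)  # [1, 2, 3]: epoch 4 did not improve
print(model.weights)     # [2.0]: restored from epoch 3, the lowest loss
```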

We can try this out again, this time adding CheckpointCallback to the list of callbacks:

model = simple_model()

model.fit(x_train,
          y_train_cat,
          batch_size=32,
          epochs=5,
          callbacks=[EpochCallback(), CheckpointCallback()],
          verbose=0)

The following output of the two callbacks together is now produced:

Starting epoch 1
Finished epoch 1
Current loss is 0.6327750086784363
Storing the model weights at epoch 1

Starting epoch 2
Finished epoch 2
Current loss is 0.3391888439655304
Storing the model weights at epoch 2

Starting epoch 3
Finished epoch 3
Current loss is 0.29216915369033813
Storing the model weights at epoch 3

Starting epoch 4
Finished epoch 4
Current loss is 0.2625095248222351
Storing the model weights at epoch 4

Starting epoch 5
Finished epoch 5
Current loss is 0.23906977474689484
Storing the model weights at epoch 5

Other classes in Keras

Besides callbacks, we can also derive classes in Keras for custom metrics (derived from keras.metrics.Metric), custom layers (derived from keras.layers.Layer), custom regularizers (derived from keras.regularizers.Regularizer) or even custom models (derived from keras.Model, for example, to change what happens when the model is called). All you have to do is follow the documented interface and override the relevant member functions of the base class. The overriding member functions must use exactly the same names and parameters as the ones they replace.
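To see why the names must match exactly, consider this hypothetical pure-Python sketch. Base, run, Correct and Misspelled are our own stand-ins for a framework base class, the framework code that calls it, and two subclasses. The framework only ever calls the method name it knows. A subclass method with a misspelled name is therefore never invoked, and no error is raised.

```python
class Base:
    """Hypothetical framework base class with a default hook."""

    def on_epoch_end(self, epoch, logs=None):
        return "default"


def run(hook_owner):
    # The hypothetical "framework" calls the hook by its exact name
    return hook_owner.on_epoch_end(0)


class Correct(Base):
    def on_epoch_end(self, epoch, logs=None):   # same name and parameters
        return "custom"


class Misspelled(Base):
    def on_epoch_ends(self, epoch, logs=None):  # typo: never called by run()
        return "custom"


print(run(Correct()))     # custom
print(run(Misspelled()))  # default, and no error is raised
```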

Below is an example from the Keras documentation:

import tensorflow as tf

class BinaryTruePositives(tf.keras.metrics.Metric):

    def __init__(self, name='binary_true_positives', **kwargs):
        super(BinaryTruePositives, self).__init__(name=name, **kwargs)
        # A persistent variable that accumulates the count across batches
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred, tf.bool)

        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.true_positives

    def reset_state(self):
        self.true_positives.assign(0.0)

m = BinaryTruePositives()
m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
print('Intermediate result:', float(m.result()))

m.update_state([1, 1, 1, 1], [0, 1, 1, 0])
print('Final result:', float(m.result()))

This reveals why a custom metric needs a class: a metric is not just a function, but one that computes its value incrementally, updating its state once per batch of data during the training cycle. The result is reported by the result() method (for example, at the end of an epoch), and the accumulated state is cleared by the reset_state() method so that the metric can start afresh in the next epoch.
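The same update/result/reset pattern can be written without TensorFlow. The following is a hypothetical plain-Python re-implementation. PlainBinaryTruePositives is our own name; it does not subclass keras.metrics.Metric and ignores sample_weight. It keeps a running count across calls. Fed the two batches from the example, it reports 1.0 after the first batch, because only one position has both a true label and a positive prediction. After the second batch it reports 3.0, since that batch adds two more.

```python
class PlainBinaryTruePositives:
    """Hypothetical plain-Python version of the metric's stateful pattern."""

    def __init__(self):
        self.true_positives = 0.0  # state that persists across batches

    def update_state(self, y_true, y_pred):
        # Accumulate: count positions where both label and prediction are truthy
        self.true_positives += sum(
            1.0 for t, p in zip(y_true, y_pred) if bool(t) and bool(p)
        )

    def result(self):
        return self.true_positives

    def reset_state(self):
        self.true_positives = 0.0  # start afresh, e.g. at a new epoch


m = PlainBinaryTruePositives()
m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
print("Intermediate result:", m.result())  # 1.0
m.update_state([1, 1, 1, 1], [0, 1, 1, 0])
print("Final result:", m.result())         # 3.0
m.reset_state()
print("After reset:", m.result())          # 0.0
```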

For details on exactly which methods have to be overridden, you should refer to the Keras documentation.

Further Reading

This section provides more resources on the topic if you are looking to go deeper.

Books

Python Fundamentals, 2018.

Websites

Python classes, https://docs.python.org/3/tutorial/classes.html
Creating custom callback in Keras, https://www.tensorflow.org/guide/keras/custom_callback
Creating custom metrics in Keras, https://keras.io/api/metrics/#creating-custom-metrics
Making new layers and models via subclassing, https://keras.io/guides/making_new_layers_and_models_via_subclassing/

Summary

In this tutorial, you discovered Python classes and their functionality.

Specifically, you learned:

Why Python classes are important.
How to define and instantiate a class, and set its attributes. 
How to create methods and pass arguments. 
What class inheritance is. 
How to use classes to implement callbacks in Keras. 

Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.



The post Python Classes and Their Use in Keras appeared first on Machine Learning Mastery.
