Pretrained Models for Transfer Learning in Keras for Computer Vision
TensorFlow is one of the most widely used libraries for machine learning, and it has built-in support for Keras: we can call Keras functionality directly through the tf.keras module. Computer vision is one of the most interesting branches of machine learning. The ImageNet dataset was a turning point for computer vision research, as it provided a large labelled image set for object recognition; it is now a standard benchmark for image classification and object detection deep learning models.
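For example, a Keras layer can be created directly through this namespace:

import tensorflow as tf

# A Keras layer accessed via the tf.keras namespace
dense = tf.keras.layers.Dense(units=10, activation="relu")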
Transfer learning is also one of the major developments in deep learning for computer vision. In transfer learning, we take a model pretrained for classification on one dataset and apply it to a new classification task, tuning the hyperparameters only slightly. Transfer learning has two main benefits:
- It requires less time to train the model, since it has already been trained on another task.
- It works well for tasks with smaller datasets, because the weights learned on a larger dataset are transferred to the new task.
Illustration of transfer learning: a model trained to classify everyday objects such as cats and dogs is reused for cancer detection by transferring its weights.
The TensorFlow Keras module provides many pretrained models that can be used for transfer learning; they are collected in the tf.keras.applications module. The table below lists the submodules and the functions they provide for instantiating these deep learning architectures:
| Module | DL Model Functions |
| --- | --- |
| densenet | DenseNet121(), DenseNet169(), DenseNet201() |
| efficientnet | EfficientNetB0(), EfficientNetB1(), EfficientNetB2(), EfficientNetB3(), EfficientNetB4(), EfficientNetB5(), EfficientNetB6(), EfficientNetB7() |
| inception_resnet_v2 | InceptionResNetV2() |
| inception_v3 | InceptionV3() |
| mobilenet | MobileNet() |
| mobilenet_v2 | MobileNetV2() |
| nasnet | NASNetLarge(), NASNetMobile() |
| resnet | ResNet101(), ResNet152() |
| resnet50 | ResNet50() |
| resnet_v2 | ResNet101V2(), ResNet152V2(), ResNet50V2() |
| vgg16 | VGG16() |
| vgg19 | VGG19() |
| xception | Xception() |
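All of these functions share the same calling convention. A minimal sketch (the input shape below is just an example, not a requirement):

import tensorflow as tf

# Full ResNet50 with its original ImageNet classifier head
resnet = tf.keras.applications.ResNet50(weights='imagenet')

# MobileNetV2 as a headless feature extractor for transfer learning
mobilenet = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)  # example input shape
)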
We write models in TensorFlow as per the example given below:
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    # First convolutional block
    layers.Conv2D(filters=32, kernel_size=5, activation="relu",
                  padding="same", input_shape=[128, 128, 3]),
    layers.MaxPool2D(),
    # Second convolutional block
    layers.Conv2D(filters=64, kernel_size=3, activation="relu", padding="same"),
    layers.MaxPool2D(),
    # Third convolutional block
    layers.Conv2D(filters=128, kernel_size=3, activation="relu", padding="same"),
    layers.MaxPool2D(),
    # Classifier head
    layers.Flatten(),
    layers.Dense(units=6, activation="relu"),
    layers.Dense(units=1, activation="sigmoid"),
])
The structure of this deep learning model, three convolution–pooling blocks followed by a dense classifier head, can be inspected with model.summary().
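A minimal sketch of compiling and inspecting the model (the optimizer and loss here are typical choices for a binary classifier, not specified in the original):

model.compile(
    optimizer="adam",
    loss="binary_crossentropy",  # matches the final sigmoid unit
    metrics=["accuracy"],
)
model.summary()  # prints the layer-by-layer structure and parameter counts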
In the same way, we can call the Xception() function from the tf.keras.applications module to add a pretrained model to our architecture. Since the model has already been trained on the 'imagenet' dataset, we reuse those weights rather than training it again, so its trainable attribute is set to False. A GlobalAveragePooling2D layer is placed after the base model, followed by a softmax activation for multiclass classification; for binary classification the activation must be sigmoid.
import tensorflow as tf

IMAGE_SIZE = [299, 299]  # example resolution (Xception's default); adjust to your data
CLASSES = ['cat', 'dog', 'horse']  # example class names; replace with your own

# Load Xception pretrained on ImageNet, without its original classifier head
pretrained_model = tf.keras.applications.Xception(
    weights='imagenet',
    include_top=False,
    input_shape=[*IMAGE_SIZE, 3]
)
# Freeze the pretrained weights so they are not updated during training
pretrained_model.trainable = False

model = tf.keras.Sequential([
    pretrained_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(len(CLASSES), activation='softmax')
])
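Training then proceeds as usual. A minimal sketch, assuming train_ds and val_ds are hypothetical tf.data.Dataset objects yielding batches of images and integer labels:

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # integer class labels
    metrics=['accuracy'],
)
history = model.fit(train_ds, validation_data=val_ds, epochs=5)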
We can use any of the other pretrained models in exactly the same way, simply by calling a different function from tf.keras.applications.
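For example, swapping Xception for ResNet50V2 changes only the function call (a sketch reusing the IMAGE_SIZE and CLASSES placeholders defined above):

pretrained_model = tf.keras.applications.ResNet50V2(
    weights='imagenet',
    include_top=False,
    input_shape=[*IMAGE_SIZE, 3]
)
pretrained_model.trainable = False

model = tf.keras.Sequential([
    pretrained_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(len(CLASSES), activation='softmax')
])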