Faster Machine Learning through Transfer Learning

Introduction

According to Andrew Ng, professor at Stanford, transfer learning will be, after supervised learning, the next driver of commercial success in Machine Learning. Transfer learning enables companies to reuse neural networks that have been trained on one task in other areas. For example, information gained from ImageNet (a generic image database) can be used for interpreting chest X-rays, identifying eye diseases, and detecting Alzheimer’s. If trained from scratch, complex neural networks can take days or weeks to train. With transfer learning, there is no need to train the whole network; only specific parts of it need to be trained. By doing so, transfer learning dramatically reduces both the time to market and the amount of data needed to train a network. The goal of this project was to use a pre-trained model, change its final layer, and apply it to flower detection.

Executive Summary

The project started with the transformation of images. Image transformation, also called data augmentation, is fundamental for training neural networks. Once that was done, a classifier was chosen to train the network. The decision was made to start with densenet121. The training ran for 10 cycles, so-called epochs. The accuracy is above 90% for both the validation and training sets. It became apparent that neural networks trained on older software versions cannot be used with newer software versions; this issue required retraining the neural network. The whole source code was published on GitHub. The project is finished and can be used for training and predictions on any machine. No code refactoring or optimization is planned.

Results

The decision was made to test the simplest setup possible. The architecture of the whole neural network consists of only one hidden layer. As a starting point, the learning rate was set to 0.003, the batch size to 64, and the number of epochs to 10. As the pre-trained model, densenet121 was chosen. It took 1 hour and 2 minutes to train the network. The accuracy on the testing set is 91.4%, meaning that out of 100 predictions, 91 were made correctly.

Project Overview

The goal of this project was to identify different species of flowers. For training, validation, and prediction, the 102 Category Flower Dataset was used. The images were already split into train/valid/test folders. The first part of the project covered the preprocessing of the images. In order to achieve high model accuracy, transformations were applied to the images. In particular, random resizing and horizontal flipping were used on the training data. For the validation and testing data, only resizing and cropping were necessary.
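A minimal sketch of such transforms using torchvision is shown below; the exact image sizes and the ImageNet normalization values are assumptions based on common practice for densenet121, not taken from the project code:

```python
from torchvision import transforms

# Training data: random resized crops plus horizontal flips for augmentation.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # ImageNet means/stds, matching the data densenet121 was pre-trained on.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Validation/testing data: deterministic resize and center crop, no augmentation.
eval_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```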

The second part of the project was all about building and training the classifier. In transfer learning, often only the final layer is changed. To do so, the pre-trained network first has to be chosen; the decision was made to go with densenet121. After loading the initial pre-trained model, a classifier was assembled. The classifier consists of only one hidden layer, with ReLU as the activation function. LogSoftmax was used as the activation function in the output layer. Dropout was used as a regularization method to prevent overfitting. The use of dropout results in a lower loss on the validation set compared to the training loss (see Figure 1.0).
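A sketch of how such a classifier can be attached to a pre-trained densenet121 in PyTorch is shown below; only the overall structure (one hidden layer, ReLU, dropout, LogSoftmax output) follows the description above, while the hidden-layer size and dropout probability are illustrative assumptions:

```python
import torch
from torch import nn
from torchvision import models

# Load the pre-trained model.
model = models.densenet121(pretrained=True)

# Freeze the pre-trained feature extractor; only the new classifier will be trained.
for param in model.parameters():
    param.requires_grad = False

# densenet121 exposes 1024 features; 102 outputs match the flower categories.
# Hidden size (256) and dropout probability (0.2) are assumed values.
model.classifier = nn.Sequential(
    nn.Linear(1024, 256),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(256, 102),
    nn.LogSoftmax(dim=1),
)
```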

Figure 1.0: Running Metrics Loss / Accuracy

As Figure 1.0 shows, the validation loss at the end of the 10th epoch is around 0.287 and the validation accuracy is around 0.922. It took 1 hour and 22 minutes to complete the training. When the training was finished, another test on the testing set was run. The accuracy on the testing set was around 0.914.
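A condensed sketch of a training/validation loop for this setup is shown below, using NLLLoss (which pairs with the LogSoftmax output) and Adam at the stated learning rate of 0.003 for 10 epochs. The optimizer choice and the loop structure are assumptions, and `model`, `train_loader`, and `valid_loader` from the earlier sketches are assumed to exist:

```python
import torch
from torch import nn, optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.NLLLoss()                                   # pairs with the LogSoftmax output layer
optimizer = optim.Adam(model.classifier.parameters(), lr=0.003)

for epoch in range(10):
    # Train only the new classifier; the feature extractor stays frozen.
    model.train()
    for images, labels in train_loader:                    # DataLoader with batch_size=64
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    # Measure validation accuracy after each epoch.
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
    print(f"Epoch {epoch + 1}: validation accuracy {correct / len(valid_loader.dataset):.3f}")
```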

The final part of the project was about saving the trained network and using it for inference (prediction). During this part, one major issue was discovered. The network was trained on a machine with PyTorch version 0.4.0, while the predictions were to be made on a machine with PyTorch version 1.4.0. It turned out that networks trained on 0.x.x versions are not compatible with 1.x.x versions. Therefore, the network was retrained on a machine with version 1.3.0, which did not cause any issues for inference.
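One common way to make checkpoints more portable is to save only the state dict plus the information needed to rebuild the classifier, rather than pickling the whole model object. The sketch below follows that pattern; the project's actual checkpoint format is not documented here, and `model` and `train_dataset` are assumed to exist:

```python
import torch
from torch import nn
from torchvision import models

# Save: keep only the weights and the class-to-index mapping, not the model object.
checkpoint = {
    "state_dict": model.state_dict(),
    "class_to_idx": train_dataset.class_to_idx,   # assumed ImageFolder dataset attribute
}
torch.save(checkpoint, "checkpoint.pth")

# Load: rebuild the architecture first, then restore the weights.
def load_checkpoint(path):
    checkpoint = torch.load(path, map_location="cpu")
    model = models.densenet121(pretrained=True)
    model.classifier = nn.Sequential(
        nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(p=0.2),
        nn.Linear(256, 102), nn.LogSoftmax(dim=1),
    )
    model.load_state_dict(checkpoint["state_dict"])
    model.class_to_idx = checkpoint["class_to_idx"]
    return model
```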

Figure 1.1: Sanity Checking

Figure 1.1 shows an example with a picture and a probability distribution over the top 5 classes. There is a probability of 99.36% that the picture shows a pink primrose.
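The top-5 probabilities shown in Figure 1.1 can be obtained by exponentiating the LogSoftmax output and calling torch.topk. The helper below is a sketch, assuming an already preprocessed image tensor and the trained `model`:

```python
import torch

def predict(image_tensor, model, topk=5):
    """Return the top-k probabilities and class indices for a single image tensor."""
    model.eval()
    with torch.no_grad():
        log_ps = model(image_tensor.unsqueeze(0))   # add a batch dimension
        ps = torch.exp(log_ps)                      # LogSoftmax output -> probabilities
        top_p, top_idx = ps.topk(topk, dim=1)
    return top_p.squeeze().tolist(), top_idx.squeeze().tolist()
```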

After sanity checking, the whole project was packaged into a command-line application. The training part as well as the inference part can be run independently from the command line on any machine. Prerequisites and source code are available here.
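The inference entry point of such a command-line application could look roughly like the sketch below; the script name, argument names, and defaults are illustrative assumptions, not the project's actual interface:

```python
# predict.py -- hypothetical entry point; the real argument names may differ
import argparse

def main():
    parser = argparse.ArgumentParser(description="Predict a flower species from an image.")
    parser.add_argument("image_path", help="path to the input image")
    parser.add_argument("checkpoint", help="path to the trained model checkpoint")
    parser.add_argument("--top_k", type=int, default=5, help="number of classes to show")
    parser.add_argument("--gpu", action="store_true", help="use a GPU if available")
    args = parser.parse_args()
    # ... load the checkpoint, preprocess the image, and call predict() here

if __name__ == "__main__":
    main()
```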

Conclusion

Transfer learning enables companies, developers, and data scientists to accelerate the AI training workflow. The main advantage of this approach is that small datasets can be used to build robust models. As this project shows, it was possible to achieve an accuracy of 91.4% in only about an hour of training. Transfer learning saves time, resources, and money.

Recommendations

There is still some low-hanging fruit for achieving even better accuracy. More than 10 different pre-trained models are available for testing and fine-tuning. Hyperparameters such as the learning rate, batch size, dropout probability, and number of epochs can also be further optimized. The network architecture can be extended with additional layers. Since the goal of the project was to evaluate the transfer learning approach in general, no focus was put on these optimizations.

The whole codebase was rewritten from a Jupyter Notebook into a command-line application. In the current version, each part of the pipeline can run independently. The code is currently not covered by unit tests. Since the project was created during Udacity's Machine Learning program, there are no plans for refactoring.

Transfer learning is already well established in image recognition, natural language processing, and speech recognition. Instead of training networks from scratch in these areas, it is recommended to research the available transfer learning options first.
