Use pretrained network
- use the `imread` function to import an image: `I = imread('filename.png');`
- use the `imshow` function to display an image: `imshow(I)`
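Put together, importing and displaying an image looks like this (`peppers.png` is a sample image shipped with MATLAB, used here as a stand-in for your own file):

```matlab
% Import an image and display it
I = imread('peppers.png');   % returns an M-by-N-by-3 uint8 array for a color image
imshow(I)
```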
Make prediction
- use the `alexnet` function to create a copy of the predefined deep network “AlexNet” in the MATLAB workspace: `net = alexnet`
- use the `classify` function to make a prediction on an image: `pred = classify(net,img)`
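The two steps combine into a minimal prediction sketch (the filename is a placeholder for your own image; `alexnet` requires the corresponding support package to be installed):

```matlab
% Load the pretrained network once; reuse it for many predictions
net = alexnet;
% 'file.jpg' is a placeholder for your own image
img = imread('file.jpg');
pred = classify(net, img)   % returns a categorical class label
```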
CNN Architecture
- use the `Layers` property to inspect the layers of the network: `ly = net.Layers`
- The `Layers` variable is an array of network layers. Use array indexing to inspect an individual layer: `layer3 = ly(3)`
- Each layer of the network has properties relevant to that type of layer. An important property of an input layer is `InputSize`, the size (dimensions) of images the network expects as input: `insz = inlayer.InputSize`
- AlexNet requires an input image of size 227-by-227-by-3: a color image 227 pixels high and 227 pixels wide, with three color channels.
- The `Classes` property of an output layer gives the names of the categories the network is trained to predict: `categorynames = outlayer.Classes`
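A quick sketch tying these properties together (the layer indices are specific to AlexNet's 25-layer array):

```matlab
net = alexnet;
ly = net.Layers;                   % 25-by-1 array of layers
inlayer = ly(1);                   % first layer is the image input layer
insz = inlayer.InputSize           % [227 227 3]
outlayer = ly(end);                % last layer is the classification output layer
categorynames = outlayer.Classes;  % the 1000 ImageNet category names
```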
Investigating Predictions
- The `classify` function gives the class to which the network assigns the highest score.
- We can obtain the predicted scores for all the classes by requesting a second output from `classify`: `[pred,scrs] = classify(net,img)`
- To get a better view of the prediction, set a threshold on the final prediction probability, e.g. one standard deviation above the median score: `thresh = median(scrs) + std(scrs);`
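For example, assuming `net`, `img`, and `categorynames` from the sections above, the confident predictions can be pulled out like this:

```matlab
[pred, scrs] = classify(net, img);   % scrs has one score per class
thresh = median(scrs) + std(scrs);   % one standard deviation above the median
highscores = scrs > thresh;          % logical mask of confident classes
bar(scrs(highscores))                % plot just the confident scores
categorynames(highscores)            % and their category names
```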
Managing Collections of Data
Image Datastore
- use the `imageDatastore` function to create a datastore in MATLAB for later use in image classification: `ds = imageDatastore('foo*.png')`
- use the `read`, `readimage`, and `readall` functions to read images:
  - `read` imports images one at a time, in order
  - `readimage` imports a single, specific image: `I = readimage(ds,n)`
  - `readall` imports all the images into a single cell array (with each image in a separate cell)
- can use an image datastore in place of an individual image in CNN functions such as `classify`: `preds = classify(net,ds)`. The result will be an array of predicted classes, one for each image in the datastore.
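A sketch of the datastore workflow (the wildcard pattern and `net` are assumptions carried over from above):

```matlab
ds = imageDatastore('*.png');   % all PNG files in the current folder
img1 = read(ds);                % first image (advances the datastore)
img2 = readimage(ds, 2);        % second image, by index
imgs = readall(ds);             % cell array containing every image
preds = classify(net, ds);      % one predicted class per image in ds
```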
Prepare Images
- can use the `InputSize` property of the input layer to see the expected image size: `expectedSize = inlayer.InputSize`
- use the `imresize` function to resize an image to match the expected input size: `imgresz = imresize(img,[numrows numcols]);`
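Combining the two steps, assuming `net` and `img` from earlier:

```matlab
inlayer = net.Layers(1);
expectedSize = inlayer.InputSize;            % [227 227 3] for AlexNet
imgresz = imresize(img, expectedSize(1:2));  % resize height and width only
pred = classify(net, imgresz);
```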
Preprocessing Images in a Datastore
- to be more efficient, perform the same preprocessing steps on the entire data set at once
- use the `augmentedImageDatastore` function to perform basic preprocessing: `auds = augmentedImageDatastore([r c],imds)`, where `r` and `c` are the expected image height and width
- use the `montage` function to display the images in a datastore: `montage(imds)`
- use the `'ColorPreprocessing'` option to convert grayscale images to 3-D arrays: `auds = augmentedImageDatastore([n m],imds,'ColorPreprocessing','gray2rgb')`. This replicates each grayscale image three times to create a 3-D (RGB) array.
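Putting the preprocessing together (assuming `imds` contains a mix of color and grayscale images and `net` is AlexNet):

```matlab
auds = augmentedImageDatastore([227 227], imds, ...
    'ColorPreprocessing', 'gray2rgb');   % resize and convert grayscale to RGB
preds = classify(net, auds);             % classify the preprocessed images
```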
Create Datastore Using Subfolders
- use the `IncludeSubfolders` option to look for images within subfolders of the given folder: `ds = imageDatastore('folder','IncludeSubfolders',true)`
Performing Transfer Learning
The benefits of transfer learning
- It is extremely easy to get started using a pretrained network like AlexNet. But you have no flexibility in the way the network operates, and the network probably won’t solve the exact problem you are trying to solve.
- You can build and train a network yourself, starting with just the network architecture and random weights. But achieving reasonable results requires a lot of effort: (1) knowledge and experience with network architecture, (2) a huge amount of training data, and (3) a lot of computer time.
- Transfer learning is an efficient solution for many problems. Training requires some data and computer time, but much less than training from scratch, and the result is a network suited to your specific problem.
Components of transfer learning
- Network layers of the pretrained network, which serve as the starting point
- Training data
- Training algorithm options: mini-batch size, maximum number of epochs, learning rate
Prepare Training Data
- Importing labels:
  - The labels needed for training can be stored in the `Labels` property of the image datastore. By default, the `Labels` property is empty.
  - We can have the datastore automatically determine the labels from the folder names by specifying the `LabelSource` option: `ds = imageDatastore(folder,'IncludeSubfolders',true,'LabelSource','foldernames')`
- Split the data:
  - use the `splitEachLabel` function to divide the images in a datastore into two separate datastores: `[ds1,ds2] = splitEachLabel(imds,p)`, where `p` is the proportion of images with each label from `imds` that should be contained in `ds1`
  - By default, `splitEachLabel` keeps the files in order. We can randomly shuffle the files by adding the optional `'randomized'` flag: `[ds1,ds2] = splitEachLabel(imds,p,'randomized')`
  - To avoid unbalanced classes in the training data set, we can also specify an exact number of files to take from each label to assign to `ds1`: `[ds1,ds2] = splitEachLabel(imds,n)`. This ensures that every label in `ds1` has `n` images, even if the categories do not all contain the same number of images.
  - We can also split the data into three sets: training, validation during training, and testing. Do this by specifying multiple values of `p` or `n` as inputs and requesting the appropriate number of datastores as outputs.
- Augment training data:
  - use the `augmentedImageDatastore` function to add variety to the images
  - use the `imageDataAugmenter` function to specify transformations of the images: rotation, reflection, translation, shear, scaling
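The data-preparation steps above might be sketched like this (the `'Flowers'` folder, split proportion, and augmentation settings are illustrative assumptions):

```matlab
% Labeled datastore from folder names
ds = imageDatastore('Flowers', 'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
% 80% of each label for training, 20% for testing, shuffled
[trainImgs, testImgs] = splitEachLabel(ds, 0.8, 'randomized');
% Example augmentation: random horizontal flips and small rotations
augmenter = imageDataAugmenter('RandXReflection', true, ...
    'RandRotation', [-20 20]);
auds = augmentedImageDatastore([227 227], trainImgs, ...
    'DataAugmentation', augmenter);
```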
Modify Network Layers
- Normally, transfer learning only requires changing the last few layers: the modified network applies the same feature extraction but outputs different classes
- To modify a preexisting network, you create a new layer, then index into the layer array that represents the network and overwrite the chosen layer with the newly created layer:
`fc = fullyConnectedLayer(12); layers(23) = fc;`
- The `fullyConnectedLayer` function creates a new fully connected layer with a given number of neurons: `fclayer = fullyConnectedLayer(n)`
- The output layer still uses the labels from the pretrained network, so we also need to replace it with a new, blank output layer.
- use the `classificationLayer` function to create a new output layer for an image classification network: `cl = classificationLayer`
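For AlexNet's 25-layer array, the replacement looks like this (12 classes is just an example count):

```matlab
net = alexnet;
layers = net.Layers;
layers(23) = fullyConnectedLayer(12);  % layer 23 is AlexNet's last fully connected layer
layers(25) = classificationLayer;      % layer 25 is the output layer; classes are set during training
```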
Set Training Options
- use the `trainingOptions` function to see the available options for the training algorithm: `opts = trainingOptions('sgdm')`. This creates a variable `opts` that contains the default options for the training algorithm “stochastic gradient descent with momentum”.
- When performing transfer learning, you will typically want to start with the `InitialLearnRate` set to a smaller value than the default of 0.01: `opts = trainingOptions('sgdm','InitialLearnRate',0.001)`
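A typical transfer-learning options sketch (the epoch and mini-batch values are illustrative, not defaults):

```matlab
opts = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...   % smaller than the 0.01 default
    'MaxEpochs', 20, ...
    'MiniBatchSize', 64);
```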
Train the Network
- “Mini-batch”
  - At each iteration, a subset of the training images, known as a mini-batch, is used to update the weights. Each iteration uses a different mini-batch. One pass through the whole training set is known as an epoch.
  - The maximum number of epochs (`MaxEpochs`) and the size of the mini-batches (`MiniBatchSize`) are parameters you can set in the training algorithm options.
  - Note that the loss and accuracy reported during training are for the mini-batch being used in the current iteration.
  - By default, the images are shuffled once prior to being divided into mini-batches. You can control this behavior with the `Shuffle` option.
- Using GPUs
  - If you have an appropriate GPU and Parallel Computing Toolbox installed, the `trainNetwork` function will automatically perform the training on the GPU; no special coding is required.
  - If not, the training will be done on your computer's CPU instead.
- Transfer Learning Example Script:
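The example script itself did not survive extraction; a sketch of the full workflow, under the same assumptions used above (a `'Flowers'` folder of labeled subfolders, AlexNet's layer indices), might look like:

```matlab
% Get training images
flwrds = imageDatastore('Flowers', 'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
[trainImgs, testImgs] = splitEachLabel(flwrds, 0.6, 'randomized');

% Modify network layers for the new classes
net = alexnet;
layers = net.Layers;
numClasses = numel(categories(trainImgs.Labels));
layers(23) = fullyConnectedLayer(numClasses);
layers(25) = classificationLayer;

% Set training options
opts = trainingOptions('sgdm', 'InitialLearnRate', 0.001);

% Perform training (images must already match the 227x227x3 input size)
[flowernet, info] = trainNetwork(trainImgs, layers, opts);

% Use the trained network on the held-out test images
flwrPreds = classify(flowernet, testImgs);
```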
Evaluate Performance
- The fields `TrainingLoss` and `TrainingAccuracy` in the `info` variable returned by `trainNetwork` contain a record of the performance of the network on the training data at each iteration.
- Determine how many of the test images the network correctly classified by comparing the predicted classifications with the known classifications. The known classifications are stored in the `Labels` property of the datastore: `flwrActual = testImgs.Labels; numCorrect = nnz(flwrPreds == flwrActual)`
- Investigate performance by class:
  - The `confusionchart` function calculates and displays the confusion matrix for the predicted classifications: `confusionchart(knownclass,predictedclass)`
  - The (j,k) element of the confusion matrix is a count of how many images from class j the network predicted to be in class k. Hence, diagonal elements represent correct classifications; off-diagonal elements represent misclassifications.
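Putting the evaluation steps together, assuming `flwrPreds`, `testImgs`, and `info` from the training run above:

```matlab
plot(info.TrainingLoss)                     % loss at each training iteration
flwrActual = testImgs.Labels;               % known classes
numCorrect = nnz(flwrPreds == flwrActual);  % count of correct predictions
fracCorrect = numCorrect / numel(flwrPreds)
confusionchart(flwrActual, flwrPreds);      % per-class breakdown
```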
Summary
Create a model
Function | Description |
---|---|
`alexnet` | Load pretrained network “AlexNet” |
supported networks | A list of available pretrained networks |
`fullyConnectedLayer` | Create new fully connected network layer |
`classificationLayer` | Create new output layer for a classification network |
Get training images
Function | Description |
---|---|
`imageDatastore` | Create datastore reference to image files |
`augmentedImageDatastore` | Preprocess a collection of image files |
`splitEachLabel` | Divide datastore into multiple datastores |
Set training algorithm options
Function | Description |
---|---|
`trainingOptions` | Create variable containing training algorithm options |
Perform training
Function | Description |
---|---|
`trainNetwork` | Perform training |
Use trained network to perform classifications
Function | Description |
---|---|
`classify` | Obtain trained network’s classifications of input images |
Evaluate trained network
Function | Description |
---|---|
`nnz` | Count non-zero elements in an array |
`confusionchart` | Calculate and display confusion matrix |
`heatmap` | Visualize confusion matrix as a heatmap |
Some resources:
- MathWorks blog on deep learning: https://blogs.mathworks.com/deep-learning/
- Deep Learning Toolbox documentation: https://au.mathworks.com/help/deeplearning/index.html