
Resume Training from Checkpoint Network

This example shows how to save checkpoint networks while training a deep learning network and resume training from a previously saved network.

Load Sample Data

Load the sample data. digitTrain4DArrayData loads the digit training set as 4-D array data. XTrain is a 28-by-28-by-1-by-5000 array, where 28 is the height and 28 is the width of the images, 1 is the number of channels, and 5000 is the number of synthetic images of handwritten digits. YTrain is a categorical vector containing the label of each observation.

[XTrain,YTrain] = digitTrain4DArrayData;
size(XTrain)
ans = 1×4

          28          28           1        5000

Display some of the images in XTrain.

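figure;
% Pick 20 random image indices from the 5000 training images
perm = randperm(size(XTrain,4),20);
for i = 1:20
    subplot(4,5,i);                  % 4-by-5 grid of axes
    imshow(XTrain(:,:,:,perm(i)));   % display one 28-by-28 grayscale image
end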

Define Network Architecture

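Define the convolutional neural network architecture: three blocks of convolution, batch normalization, and ReLU layers with pooling between them, followed by a fully connected layer with 10 outputs (one per digit class), a softmax layer, and a classification layer.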

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2) 
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    averagePooling2dLayer(7)  
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
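To see why the last pooling layer uses a pool size of 7, you can trace the spatial size of the feature maps by hand. The sketch below assumes that the convolutions with 'Padding','same' and stride 1 keep the spatial size, that pooling layers use the usual unpadded output-size formula floor((inSize - poolSize)/stride) + 1, and that averagePooling2dLayer uses its default stride of 1. The helper poolOut is defined only for this illustration.

```matlab
% Output size of an unpadded pooling layer along one spatial dimension
poolOut = @(inSize,poolSize,stride) floor((inSize - poolSize)/stride) + 1;

s = 28;                 % imageInputLayer([28 28 1]); 'same' convolutions keep this size
s = poolOut(s,2,2);     % first maxPooling2dLayer(2,'Stride',2)  -> 14
s = poolOut(s,2,2);     % second maxPooling2dLayer(2,'Stride',2) -> 7
s = poolOut(s,7,1);     % averagePooling2dLayer(7), assumed default stride 1 -> 1

% The fully connected layer therefore receives 1-by-1-by-32 input
% (32 = number of filters in the last convolution2dLayer).
fprintf('Spatial size before fullyConnectedLayer: %d-by-%d\n', s, s);
```

With a pool size of 7, the average pooling layer collapses each 7-by-7 feature map to a single value.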

Specify Training Options and Train Network

Specify training options for stochastic gradient descent with momentum (SGDM) and specify the path for saving the checkpoint networks.

checkpointPath = pwd;
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',20, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);
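If you prefer to keep checkpoint files out of the current folder, you could define checkpointPath as a dedicated folder before calling trainingOptions. This is a minimal sketch; the folder name digitCheckpoints is hypothetical, and the sketch creates the folder first on the assumption that trainingOptions expects CheckpointPath to name an existing folder.

```matlab
% Hypothetical alternative to checkpointPath = pwd: a dedicated folder
checkpointPath = fullfile(tempdir,'digitCheckpoints');

% Create the folder if it does not exist yet (assumed requirement of
% trainingOptions for the 'CheckpointPath' value)
if ~isfolder(checkpointPath)
    mkdir(checkpointPath);
end
```

If you use a folder other than the current one, load checkpoint files by their full path, for example fullfile(checkpointPath,fileName).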

Train the network. trainNetwork uses a GPU if one is available; otherwise, it uses the CPU. trainNetwork saves one checkpoint network each epoch and automatically assigns a unique name to each checkpoint file.

net1 = trainNetwork(XTrain,YTrain,layers,options);

Load Checkpoint Network and Resume Training

Suppose that training was interrupted and did not complete. Rather than restarting the training from the beginning, you can load the last checkpoint network and resume training from that point. trainNetwork saves the checkpoint files with file names in the form net_checkpoint__195__2018_07_13__11_59_10.mat, where 195 is the iteration number, 2018_07_13 is the date, and 11_59_10 is the time at which trainNetwork saved the network. Each checkpoint file stores the network in a variable named net.
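Because the iteration number, date, and time are encoded in the file name, you can recover them with a regular expression. The pattern below is a sketch that assumes the exact naming scheme described above (two underscores between fields, single underscores inside the date and time); fileName is the example name from this section.

```matlab
% Parse a checkpoint file name of the form
% net_checkpoint__<iteration>__<yyyy_MM_dd>__<HH_mm_ss>.mat
fileName = 'net_checkpoint__195__2018_07_13__11_59_10.mat';

tokens = regexp(fileName, ...
    '^net_checkpoint__(\d+)__(\d{4}_\d{2}_\d{2})__(\d{2}_\d{2}_\d{2})\.mat$', ...
    'tokens','once');

iteration = str2double(tokens{1});   % iteration number as a double
savedDate = tokens{2};               % date string, e.g. '2018_07_13'
savedTime = tokens{3};               % time string, e.g. '11_59_10'
fprintf('Saved at iteration %d on %s at %s\n', iteration, savedDate, savedTime);
```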

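Load the checkpoint network into the workspace. The file name on your system will differ; replace it with the name of a checkpoint file in checkpointPath.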

load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')
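Instead of hard-coding the file name, you could locate the most recently saved checkpoint automatically. The function below is a hypothetical helper, not part of Deep Learning Toolbox; save it as latestCheckpoint.m. It assumes every checkpoint file follows the naming scheme described above and orders the files by the date and time embedded in their names, so that iteration counts from separate training runs in the same folder do not mislead it.

```matlab
function latestFile = latestCheckpoint(checkpointPath)
% LATESTCHECKPOINT Hypothetical helper: return the full path of the most
% recently saved checkpoint file in checkpointPath, using the date and
% time that appear in each file name.
files = dir(fullfile(checkpointPath,'net_checkpoint__*.mat'));
if isempty(files)
    error('latestCheckpoint:noFiles', ...
        'No checkpoint files found in %s.', checkpointPath);
end

stamps = zeros(numel(files),1);
for k = 1:numel(files)
    % Capture year, month, day, hour, minute, second from
    % net_checkpoint__<iter>__<yyyy>_<MM>_<dd>__<HH>_<mm>_<ss>.mat
    tok = regexp(files(k).name, ...
        '__(\d+)_(\d+)_(\d+)__(\d+)_(\d+)_(\d+)\.mat$','tokens','once');
    if isempty(tok)
        stamps(k) = -Inf;                       % ignore unexpected names
    else
        stamps(k) = datenum(str2double(tok));   % 1-by-6 date vector
    end
end

% Pick the file with the latest timestamp
[~,idx] = max(stamps);
latestFile = fullfile(files(idx).folder,files(idx).name);
end
```

Usage: load(latestCheckpoint(checkpointPath),'net')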

Specify the training options again and reduce the maximum number of epochs to account for the epochs that completed before the interruption. You can also adjust other training options, such as the initial learning rate.
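One way to arrive at the value of 15 used below is to convert the checkpoint iteration into completed epochs. This worked example assumes the trainingOptions default mini-batch size of 128 and that trainNetwork discards the observations that do not fill a complete final mini-batch in each epoch; under those assumptions, 5000 observations give 39 iterations per epoch.

```matlab
% Assumed values; see the lead-in for the mini-batch assumptions
numObservations = 5000;   % size(XTrain,4)
miniBatchSize   = 128;    % assumed trainingOptions default
originalEpochs  = 20;     % MaxEpochs in the first training run
checkpointIter  = 195;    % iteration number from the checkpoint file name

itersPerEpoch   = floor(numObservations/miniBatchSize);  % complete mini-batches per epoch
epochsCompleted = floor(checkpointIter/itersPerEpoch);    % epochs finished at the checkpoint
remainingEpochs = originalEpochs - epochsCompleted;       % epochs left to reach 20 in total
fprintf('%d iterations/epoch, %d epochs done, %d remaining\n', ...
    itersPerEpoch, epochsCompleted, remainingEpochs);
```

With these numbers, 195 iterations correspond to 5 complete epochs, leaving 15 of the original 20.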

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',15, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);

Resume training with the new training options, passing the layers of the loaded checkpoint network (net.Layers) to trainNetwork. If the checkpoint network is a DAG network, then pass layerGraph(net) instead of net.Layers.
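To handle both cases without editing the call by hand, you could choose the layers argument based on the class of the loaded network. The function below is a hypothetical helper, not a toolbox function; save it as checkpointLayers.m. It assumes that a DAG checkpoint network is an instance of the DAGNetwork class and treats any other network (for example, a SeriesNetwork) as a layer array.

```matlab
function layersIn = checkpointLayers(net)
% CHECKPOINTLAYERS Hypothetical helper: return the layers argument for
% resuming training with trainNetwork from a loaded checkpoint network.
if isa(net,'DAGNetwork')
    layersIn = layerGraph(net);   % DAG network: pass a layer graph
else
    layersIn = net.Layers;        % series network: pass the layer array
end
end
```

Usage: net2 = trainNetwork(XTrain,YTrain,checkpointLayers(net),options);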

net2 = trainNetwork(XTrain,YTrain,net.Layers,options);
