
Custom Training with Multiple GPUs in Experiment Manager

This example shows how to configure multiple parallel workers to collaborate on each trial of a custom training experiment. In this example, parallel workers train on portions of the overall mini-batch in each trial of an image classification experiment. During training, a DataQueue object sends training progress information back to Experiment Manager. If you have a supported GPU, then training happens on the GPU. For more information, see GPU Computing Requirements (Parallel Computing Toolbox).

As an alternative, you can set up a parallel custom training loop that runs a single trial of this experiment programmatically. For more information, see Train Network in Parallel with Custom Training Loop.

Open Experiment

First, open the example. Experiment Manager loads a project with a preconfigured experiment that you can inspect and run. To open the experiment, in the Experiment Browser pane, double-click the name of the experiment (ParallelCustomLoopExperiment).

Custom training experiments consist of a description, a table of hyperparameters, and a training function. For more information, see Configure Custom Training Experiment.

The Description field contains a textual description of the experiment. For this example, the description is:

Use multiple parallel workers to train an image classification network.
Each trial uses a different initial learning rate and momentum.

The Hyperparameters section specifies the hyperparameter values to use for the experiment. When you run the experiment, Experiment Manager trains the network using every combination of hyperparameter values specified in the hyperparameter table. This example uses two hyperparameters (see the sketch after this list):

  • InitialLearnRate sets the initial learning rate used for training. If the learning rate is too low, then training takes a long time. If the learning rate is too high, then training can reach a suboptimal result or diverge. The best learning rate depends on your data as well as the network you are training.

  • Momentum specifies the contribution of the gradient step from the previous iteration to the current iteration of stochastic gradient descent with momentum.
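
For illustration only, suppose the hyperparameter table contained the hypothetical values below (these are not the values in the preconfigured experiment). Experiment Manager would then run 2 x 2 = 4 trials, and each trial would read its own combination from the params structure passed to the training function:

% Hypothetical hyperparameter table (values for illustration only):
%   InitialLearnRate: [0.001 0.01]
%   Momentum:         [0.9 0.95]
% Each trial receives one combination through the params input:
initialLearnRate = params.InitialLearnRate;   % for example, 0.001
momentum = params.Momentum;                   % for example, 0.9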

The Training Function specifies the training data, network architecture, training options, and training procedure used by the experiment. The input to the training function is a structure with fields from the hyperparameter table and an experiments.Monitor object that you can use to track the progress of the training, record values of the metrics used by the training, and produce training plots. The training function returns a structure that contains the trained network, the training loss, and the validation accuracy. Experiment Manager saves this output, so you can export it to the MATLAB workspace when the training is complete. The training function has six sections.

  • Initialize Output sets the initial value of the network, training loss, and validation accuracy to empty arrays to indicate that the training has not started.

output.network = [];
output.loss = [];
output.accuracy = [];
  • Load Training and Test Data defines the training and test data for the experiment as imageDatastore objects. The experiment uses the Digits data set, which consists of 10,000 28-by-28 pixel grayscale images of digits from 0 to 9, categorized by the digit they represent. For more information on this data set, see Image Data Sets.

monitor.Status = "Loading Data";
digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...
    "nndatasets","DigitDataset");
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");
classes = categories(imdsTrain.Labels);
numClasses = numel(classes);
XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
YTest = imdsTest.Labels;
  • Define Network Architecture defines the architecture for the image classification network. This network architecture includes batch normalization layers that track the mean and variance statistics of the data set. When training in parallel, to ensure that the network state reflects the whole mini-batch, combine the statistics from all of the workers at the end of each iteration. Otherwise, the network state can diverge across the workers. If you are training stateful recurrent neural networks (RNNs), for example, using sequence data that has been split into smaller sequences to train networks containing LSTM or GRU layers, you must also manage the state between the workers. To train the network with a custom training loop and enable automatic differentiation, the training function converts the layer graph to a dlnetwork object. The sketch after the following code shows where the batch normalization statistics are stored in the network.

monitor.Status = "Creating Network";
layers = [
    imageInputLayer([28 28 1],Normalization="none")
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)];
lgraph = layerGraph(layers);
net = dlnetwork(lgraph);
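
The batch normalization statistics that the workers must keep synchronized are stored in the State property of the dlnetwork object. As an illustration (this code is not part of the experiment), you can inspect them after creating the network:

% Illustration only: State is a table with one row per state parameter.
% Each batchNormalizationLayer contributes a TrainedMean row and a
% TrainedVariance row; these are the rows that the aggregateState helper
% function in Appendix 2 combines across workers.
net.State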
  • Set Up Parallel Environment determines if GPUs are available for MATLAB to use. If there are GPUs available, then train on the GPUs. If no parallel pool exists, create one with as many workers as GPUs. If there are no GPUs available, then train on the CPUs. If no parallel pool exists, create one with the default number of workers.

monitor.Status = "Starting Parallel Pool";
pool = gcp("nocreate");
if canUseGPU
    executionEnvironment = "gpu";
    if isempty(pool)
        numberOfGPUs = gpuDeviceCount("available");
        pool = parpool(numberOfGPUs);
    end
else
    executionEnvironment = "cpu";
    if isempty(pool)
        pool = parpool;
    end
end
N = pool.NumWorkers;
  • Specify Training Options defines the training options used by the experiment. In this example, Experiment Manager trains the network with a mini-batch size of 128 for 20 epochs using the initial learning rate and momentum defined in the hyperparameter table. If you are training on a GPU, the mini-batch size scales up linearly with the number of GPUs to keep the workload on each GPU constant. For more information, see Deep Learning with MATLAB on Multiple GPUs. A worked example of the mini-batch split appears after this code.

numEpochs = 20;
miniBatchSize = 128;
velocity = [];
initialLearnRate = params.InitialLearnRate;
momentum = params.Momentum;
decay = 0.01;
if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N;
end
workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)];
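
As an illustration (not part of the experiment), suppose you train on the CPU with a pool of 6 workers. Because 128 is not divisible by 6, the remainder is distributed one observation at a time to the first workers:

% Illustration only: splitting a mini-batch of 128 across 6 CPU workers
N = 6;
miniBatchSize = 128;
workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));   % [21 21 21 21 21 21]
remainder = miniBatchSize - sum(workerMiniBatchSize);          % 2
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)]
% workerMiniBatchSize is now [22 22 21 21 21 21], which sums to 128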
  • Train Model defines the parallel custom training loop used by the experiment. To execute the code simultaneously on all the workers, the training function uses an spmd block that cannot contain break, continue, or return statements. As a result, you cannot interrupt a trial of the experiment while training is in progress. If you press Stop, Experiment Manager runs the current trial to completion before stopping the experiment. For more information on the parallel custom training loop, see Appendix 1 at the end of this example.

monitor.Metrics = ["TrainingLoss" "ValidationAccuracy"];
monitor.XLabel = "Iteration";
monitor.Status = "Training";
Q = parallel.pool.DataQueue;
updateFcn = @(x) updateTrainingProgress(x,monitor);
afterEach(Q,updateFcn);
spmd
    workerImds = partition(imdsTrain,N,labindex);
    workerImds.ReadSize = workerMiniBatchSize(labindex);
    workerVelocity = velocity;
    iteration = 0;
    lossArray = [];
    accuracyArray = [];
    for epoch = 1:numEpochs
        reset(workerImds);
        workerImds = shuffle(workerImds);
        if ~monitor.Stop
            while gop(@and,hasdata(workerImds))
                iteration = iteration + 1;
                [workerXBatch,workerTBatch] = read(workerImds);
                workerXBatch = cat(4,workerXBatch{:});
                workerNumObservations = numel(workerTBatch.Label);
                workerXBatch =  single(workerXBatch) ./ 255;
                workerY = zeros(numClasses,workerNumObservations,"single");
                for c = 1:numClasses
                    workerY(c,workerTBatch.Label==classes(c)) = 1;
                end
                workerX = dlarray(workerXBatch,"SSCB");
                if executionEnvironment == "gpu"
                    workerX = gpuArray(workerX);
                end
                [workerLoss,workerGradients,workerState] = dlfeval(@modelLoss,net,workerX,workerY);
                workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
                loss = gplus(workerNormalizationFactor*extractdata(workerLoss));
                net.State = aggregateState(workerState,workerNormalizationFactor);
                workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});
                learnRate = initialLearnRate/(1 + decay*iteration);
                [net.Learnables,workerVelocity] = sgdmupdate(net.Learnables,workerGradients,workerVelocity,learnRate,momentum);
            end
            if labindex == 1
                YPredScores = predict(net,dlarray(XTest,"SSCB"));
                [~,idx] = max(YPredScores,[],1);
                YPred = classes(idx);
                accuracy = mean(YPred==YTest);
                lossArray = [lossArray; iteration, loss];
                accuracyArray = [accuracyArray; iteration, accuracy];
                data = [numEpochs epoch iteration loss accuracy];
                send(Q,gather(data));
            end
        end
    end
end

To inspect the training function, under Training Function, click Edit. The training function opens in MATLAB® Editor. In addition, the code for the training function appears in Appendix 1 at the end of this example.

Run Experiment

When you run the experiment, Experiment Manager trains the network defined by the training function multiple times. Each trial uses a different combination of hyperparameter values.

Because this experiment uses the parallel pool for this MATLAB session, you cannot train multiple trials at the same time. On the Experiment Manager toolstrip, under Mode, select Sequential and click Run. Alternatively, to offload the experiment as a batch job, set Mode to Batch Sequential, specify your Cluster and Pool Size, and click Run. For more information, see Offload Experiments as Batch Jobs to Cluster.

A table of results displays the training loss and validation accuracy for each trial.

To display the training plot and track the progress of each trial while the experiment is running, under Review Results, click Training Plot.

Note that the training function for this experiment uses an spmd statement, which cannot contain break, continue, or return statements. As a result, you cannot interrupt a trial of the experiment while training is in progress. If you click Stop, Experiment Manager runs the current trial to completion before stopping the experiment.

Evaluate Results

To find the best result for your experiment, sort the table of results by validation accuracy.

  1. Point to the ValidationAccuracy column.

  2. Click the triangle icon.

  3. Select Sort in Descending Order.

The trial with the highest validation accuracy appears at the top of the results table.

To test the best trial in your experiment, plot a confusion matrix.

  1. In the results table, select the trial with the highest validation accuracy.

  2. On the Experiment Manager toolstrip, click Export > Training Output.

  3. In the dialog window, enter the name of a workspace variable for the exported training output. The default name is trainingOutput.

  4. Create a confusion matrix by calling the drawConfusionMatrix function, which is listed in Appendix 2 at the end of this example. As the input to the function, use the exported training output and the fraction of the Digits data set to use as a test set. For instance, in the MATLAB Command Window, enter:

drawConfusionMatrix(trainingOutput,0.5)

The function creates a confusion matrix using half of the images in the data set.

To record observations about the results of your experiment, add an annotation.

  1. In the results table, right-click the ValidationAccuracy cell of the best trial.

  2. Select Add Annotation.

  3. In the Annotations pane, enter your observations in the text box.

For more information, see Sort, Filter, and Annotate Experiment Results.

Close Experiment

In the Experiment Browser pane, right-click the name of the project and select Close Project. Experiment Manager closes all of the experiments and results contained in the project.

Appendix 1: Training Function

This function configures the training data, network architecture, and training options for the experiment. To execute the code simultaneously on all the workers, the function uses an spmd block. Within the spmd block, labindex gives the index of the worker currently executing the code. Before training, the function partitions the datastore for each worker by using the partition function, and sets ReadSize to the mini-batch size of the worker. For each epoch, the function resets and shuffles the datastore. For each iteration in the epoch, the function:

  • Reads a mini-batch from the datastore and processes the data for training.

  • Computes the loss and the gradients of the network on each worker by calling dlfeval on the modelLoss function.

  • Obtains the overall loss by using gplus to add together the normalized cross-entropy losses from all of the workers.

  • Aggregates and updates the gradients of all workers using the dlupdate function with the aggregateGradients function.

  • Aggregates the state of the network on all workers using the aggregateState function.

  • Updates the network learnable parameters with the sgdmupdate function.

At the end of each epoch, the function uses only the first worker to send the training progress information back to the client.
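
The sketch below (not part of the experiment) shows the underlying DataQueue pattern that the training function relies on: the client registers a callback with afterEach, and any worker can push data back to the client with send.

% Illustration only: minimal DataQueue round trip
Q = parallel.pool.DataQueue;
afterEach(Q,@(data) disp(data));   % callback runs on the client
parfor i = 1:4
    send(Q,i);                     % each worker sends a value to the client
end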

Input

  • params is a structure with fields from the Experiment Manager hyperparameter table.

  • monitor is an experiments.Monitor object that you can use to track the progress of the training, update information fields in the results table, record values of the metrics used by the training, and produce training plots.

Output

  • output is a structure that contains the trained dlnetwork object, the training loss array, and the validation accuracy array. Experiment Manager saves this output, so you can export it to the MATLAB workspace when the training is complete.

function output = ParallelCustomLoopExperiment_training(params,monitor)

output.network = [];
output.loss = [];
output.accuracy = [];

monitor.Status = "Loading Data";
digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...
    "nndatasets","DigitDataset");
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
YTest = imdsTest.Labels;

monitor.Status = "Creating Network";

layers = [
    imageInputLayer([28 28 1],Normalization="none")
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)];

lgraph = layerGraph(layers);
net = dlnetwork(lgraph);

monitor.Status = "Starting Parallel Pool";

pool = gcp("nocreate");

if canUseGPU
    executionEnvironment = "gpu";
    if isempty(pool)
        numberOfGPUs = gpuDeviceCount("available");
        pool = parpool(numberOfGPUs);
    end
else
    executionEnvironment = "cpu";
    if isempty(pool)
        pool = parpool;
    end
end

N = pool.NumWorkers;

numEpochs = 20;
miniBatchSize = 128;
velocity = [];
initialLearnRate = params.InitialLearnRate;
momentum = params.Momentum;
decay = 0.01;

if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N;
end

workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)];

monitor.Metrics = ["TrainingLoss" "ValidationAccuracy"];
monitor.XLabel = "Iteration";
monitor.Status = "Training";

Q = parallel.pool.DataQueue;
updateFcn = @(x) updateTrainingProgress(x,monitor);
afterEach(Q,updateFcn);

spmd
    workerImds = partition(imdsTrain,N,labindex);
    workerImds.ReadSize = workerMiniBatchSize(labindex);
    
    workerVelocity = velocity;
   
    iteration = 0;
    lossArray = [];
    accuracyArray = [];
    
    for epoch = 1:numEpochs
        reset(workerImds);
        workerImds = shuffle(workerImds);
        
        if ~monitor.Stop
            while gop(@and,hasdata(workerImds))
                iteration = iteration + 1;
                
                [workerXBatch,workerTBatch] = read(workerImds);
                workerXBatch = cat(4,workerXBatch{:});
                workerNumObservations = numel(workerTBatch.Label);
    
                workerXBatch =  single(workerXBatch) ./ 255;
                
                workerY = zeros(numClasses,workerNumObservations,"single");
                for c = 1:numClasses
                    workerY(c,workerTBatch.Label==classes(c)) = 1;
                end
                
                workerX = dlarray(workerXBatch,"SSCB");
                
                if executionEnvironment == "gpu"
                    workerX = gpuArray(workerX);
                end
                
                [workerLoss,workerGradients,workerState] = dlfeval(@modelLoss,net,workerX,workerY);
                
                workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
                loss = gplus(workerNormalizationFactor*extractdata(workerLoss));
                
                net.State = aggregateState(workerState,workerNormalizationFactor);
                
                workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});
                
                learnRate = initialLearnRate/(1 + decay*iteration);
                
                [net.Learnables,workerVelocity] = sgdmupdate(net.Learnables,workerGradients,workerVelocity,learnRate,momentum);
            end             
            
            if labindex == 1
                YPredScores = predict(net,dlarray(XTest,"SSCB"));
                [~,idx] = max(YPredScores,[],1);
                YPred = classes(idx);
                accuracy = mean(YPred==YTest);
                
                lossArray = [lossArray; iteration, loss];
                accuracyArray = [accuracyArray; iteration, accuracy];
                
                data = [numEpochs epoch iteration loss accuracy];
                send(Q,gather(data)); 
            end  
        end
    end
end

output.network = net{1};
output.loss = lossArray{1};
output.accuracy = accuracyArray{1};

delete(gcp("nocreate"));
end

Appendix 2: Custom Training Helper Functions

This function computes the loss and the gradients of the loss with respect to the learnable parameters of the network. It computes the network outputs for a mini-batch X using forward and softmax, and calculates the loss, given the true outputs, using cross-entropy. When you call this function with dlfeval, automatic differentiation is enabled, and dlgradient can compute the gradients of the loss with respect to the learnables automatically.

function [loss,gradients,state] = modelLoss(net,X,Y)
[YPred,state] = forward(net,X);
YPred = softmax(YPred);
loss = crossentropy(YPred,Y);
gradients = dlgradient(loss,net.Learnables);
end

This function displays training progress information and updates metric values that come from the workers. The DataQueue object in this example calls this function every time a worker sends data.

function updateTrainingProgress(data,monitor)
monitor.Progress = (data(2)/data(1))*100;
recordMetrics(monitor,data(3), ...
    TrainingLoss=data(4), ...
    ValidationAccuracy=data(5));
end

This function aggregates the gradients on all workers by adding them together. gplus adds together and replicates all the gradients on the workers. Before adding them together, the function normalizes them by multiplying them by a factor that represents the proportion of the overall mini-batch that the worker is working on.

function gradients = aggregateGradients(gradients,factor)
gradients = gplus(factor*gradients);
end

This function aggregates the network state on all workers. The network state contains the trained batch normalization statistics of the data set. Since each worker sees only a portion of the mini-batch, aggregate the network state so that the statistics are representative of the statistics across all the data. For each mini-batch, the combined mean is calculated as a weighted average of the mean across the workers for each iteration. The combined variance is calculated according to the formula:

$$s_c^2 = \frac{1}{M} \sum_{j=1}^{N} m_j \left[ s_j^2 + (\bar{x}_j - \bar{x}_c)^2 \right]$$

$N$ is the total number of workers, $M$ is the total number of observations in a mini-batch, $m_j$ is the number of observations processed on the $j$ th worker, $\bar{x}_j$ and $s_j^2$ are the mean and variance statistics calculated on that worker, and $\bar{x}_c$ is the combined mean across all workers.
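
To see that this weighted combination reproduces the statistics of the full mini-batch, you can run a small numeric check on the client (illustration only, not part of the example):

% Illustration only: combining per-worker statistics reproduces the
% statistics of the full mini-batch.
x1 = [1 2 3];                      % observations processed on worker 1
x2 = [10 20];                      % observations processed on worker 2
M = numel(x1) + numel(x2);         % total observations in the mini-batch
factor1 = numel(x1)/M;             % normalization factors m_j/M
factor2 = numel(x2)/M;

combinedMean = factor1*mean(x1) + factor2*mean(x2);

% var(X,1) is the population variance, matching batch normalization statistics
combinedVar = factor1*(var(x1,1) + (mean(x1) - combinedMean)^2) + ...
              factor2*(var(x2,1) + (mean(x2) - combinedMean)^2);

abs(combinedMean - mean([x1 x2])) < 1e-10   % true
abs(combinedVar - var([x1 x2],1)) < 1e-10   % true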

function state = aggregateState(state,factor)

    numrows = size(state,1);
    
    for j = 1:numrows
        isBatchNormalizationState = state.Parameter(j) =="TrainedMean"...
            && state.Parameter(j+1) =="TrainedVariance"...
            && state.Layer(j) == state.Layer(j+1);
        
        if isBatchNormalizationState
            meanVal = state.Value{j};
            varVal = state.Value{j+1};
            
            combinedMean = gplus(factor*meanVal);
                   
            combinedVarTerm = factor.*(varVal + (meanVal - combinedMean).^2);        
            
            state.Value(j) = {combinedMean};
            state.Value(j+1) = {gplus(combinedVarTerm)};
           
        end
    end
end

Appendix 3: Create Confusion Matrix

This function takes as input the exported training output and the fraction of the Digits data set to use as a test set. It extracts the trained network from the training output and creates a confusion matrix chart.

function drawConfusionMatrix(trainingOutput,testSize)

dataFolder = fullfile(toolboxdir("nnet"), ...
    "nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
[~,imdsTest] = splitEachLabel(imds,testSize,"randomized");

classes = categories(imdsTest.Labels);
trainedNet = trainingOutput.network;

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
trueLabels = imdsTest.Labels;

YPredScores = predict(trainedNet,dlarray(XTest,"SSCB"));
[~,idx] = max(YPredScores,[],1);
predictedLabels = classes(idx);

figure
confusionchart(trueLabels,categorical(predictedLabels), ...
    ColumnSummary="column-normalized", ...
    RowSummary="row-normalized", ...
    Title="Confusion Matrix for Digits Data Set");
cm = gcf;
cm.Position(3) = cm.Position(3)*1.5;

end
