Quantize Multiple-Input Network Using Image and Feature Data
This example shows how to quantize a network with multiple inputs. The network classifies handwritten digits using both image and feature input data. To learn more about multi-input networks, see Multiple-Input and Multiple-Output Networks.
Load Training Data
Load the training data. The digitTrain4DArrayData function loads the images, labels, and clockwise rotation angles of the digits data set as numeric arrays. To learn more about the digits data set used in this example, see Data Sets for Deep Learning.
[X1Train,TTrain,X2Train] = digitTrain4DArrayData;
To train the network using both the image and feature data, create a single datastore that contains the training predictors and responses. Convert the numeric arrays to datastores using arrayDatastore. Use the combine function to combine the datastores into a single datastore.
dsX1Train = arrayDatastore(X1Train,IterationDimension=4);
dsX2Train = arrayDatastore(X2Train);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsX1Train,dsX2Train,dsTTrain);
classes = categories(TTrain);
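To see what the combined datastore returns, it can help to read a single observation. The following is a minimal sketch that uses small synthetic arrays (hypothetical stand-ins, not the digits data) and assumes the default arrayDatastore settings, where each read returns one observation in a cell.

```matlab
% Toy stand-ins for the digit data (hypothetical values, not the real data set).
X1 = rand(28,28,1,5);                 % five 28-by-28 grayscale "images"
X2 = linspace(-45,45,5)';             % one scalar feature per observation
T  = categorical([0;1;2;3;4]);        % one label per observation

dsX1 = arrayDatastore(X1,IterationDimension=4);
dsX2 = arrayDatastore(X2);
dsT  = arrayDatastore(T);
ds   = combine(dsX1,dsX2,dsT);

% Each read returns one row of the form {image, feature, label}.
sample = read(ds);
disp(size(sample))
disp(class(sample{3}))
```

The column order of the row follows the order of the arguments to combine, which is why the responses are passed last.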
Specify Training Options
Specify the training options.
Train using the SGDM optimizer.
Train for 15 epochs.
Train with a learning rate of 0.01.
Display the training progress in a plot.
Suppress the verbose output.
options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.01, ...
    Plots="training-progress", ...
    Verbose=0);
Train Network
Train the network using the trainDigitsNetwork function. To learn more about how to define the network architecture, see Train Network on Image and Feature Data.
net = trainDigitsNetwork(dsTrain,classes,options)

net =
dlnetwork with properties:
Layers: [10×1 nnet.cnn.layer.Layer]
Connections: [9×2 table]
Learnables: [8×3 table]
State: [2×3 table]
InputNames: {'imageinput' 'features'}
OutputNames: {'softmax'}
Initialized: 1
View summary with summary.
Test Network
Test the classification accuracy of the network by comparing the predictions on a test set of data with the true labels. Load the test data and create a combined datastore containing the images and features.
[X1Test,TTest,X2Test] = digitTest4DArrayData;
dsX1Test = arrayDatastore(X1Test,IterationDimension=4);
dsX2Test = arrayDatastore(X2Test);
dsTTest = arrayDatastore(TTest);
dsTest = combine(dsX1Test,dsX2Test,dsTTest);
Create a minibatchqueue object that batches and preprocesses the test data for dlnetwork prediction.
mbqTest = minibatchqueue(dsTest,...
    MiniBatchSize=32,...
    MiniBatchFcn=@preprocessMiniBatchTraining, ...
    OutputAsDlarray=[1 1 1], ...
    OutputEnvironment=["auto","auto","auto"], ...
    PartialMiniBatch="return", ...
    MiniBatchFormat=["SSCB","BC",""]);
Use the modelAccuracy function to evaluate the accuracy of the network on the test data set.
accuracyOriginal = modelAccuracy(net,mbqTest,classes,dsTest.numpartitions)
accuracyOriginal = 98.4600
Use the modelPredictions function to compute the predicted classes. Visualize the predictions using a confusionchart.
YTest = modelPredictions(net,mbqTest,classes);
figure
confusionchart(TTest,YTest)

Evaluate the classification accuracy based on the model predictions.
accuracy = mean(YTest == TTest)
accuracy = 0.9846
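Note that this value is a fraction, whereas modelAccuracy returns a percentage. As a small illustration of the computation, the sketch below uses hypothetical labels (not results from this example) and shows that the mean of the element-wise comparison is the fraction of correct predictions.

```matlab
% Hypothetical labels for four observations; the third prediction is wrong.
TTrue = categorical([3;7;1;0],0:9);
YPred = categorical([3;7;2;0],0:9);

isCorrect = (YPred == TTrue);          % logical vector, true where labels match
accuracyFraction = mean(isCorrect);    % fraction of correct predictions
accuracyPercent  = 100*accuracyFraction;
```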
To observe the classification results, view some of the images with their prediction labels.
idx = randperm(size(X1Test,4),9);
figure
tiledlayout(3,3)
for i = 1:9
    nexttile
    I = X1Test(:,:,:,idx(i));
    imshow(I)
    label = string(YTest(idx(i)));
    title("Predicted Label: " + label)
end

Quantize Network
To quantize a network with multiple inputs, the input data for the calibrate and validate functions must be a combinedDatastore or a transformedDatastore.
For validation, the datastore must output a cell array with (numInputs+1) columns, where numInputs is the number of inputs to the network. In this case, the first numInputs columns specify the predictors for each input and the last column specifies the responses.
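A quick way to check this layout before calling validate is to preview one row of the datastore and count its columns. The sketch below is illustrative only: it builds a small synthetic datastore, and numInputs is set by hand on the assumption of two network inputs, as in this example.

```matlab
numInputs = 2;   % assumption: one image input and one feature input

% Small synthetic datastore with the {input1, input2, response} layout.
ds = combine( ...
    arrayDatastore(rand(28,28,1,4),IterationDimension=4), ...
    arrayDatastore(rand(4,1)), ...
    arrayDatastore(categorical([0;1;2;3])));

row = preview(ds);
hasValidationLayout = iscell(row) && size(row,2) == numInputs + 1;
disp(hasValidationLayout)
```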
Create calibration and validation datastores using random subsets of the test data set.
randomImagesCalibration = randperm(4999);
calibrationDataStore = dsTest.subset(randomImagesCalibration(1:200));
randomImagesValidation = randperm(4999);
validationDataStore = dsTest.subset(randomImagesValidation(1:100));
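Because the two permutations are drawn independently, the calibration and validation subsets can share observations. If disjoint subsets are preferred, one possible approach is to slice a single permutation. This is a sketch under stated assumptions: the variable names are hypothetical and numObservations assumes the 5000-image digits test set.

```matlab
numObservations = 5000;   % assumption: number of observations in the test set
numCalibration  = 200;
numValidation   = 100;

% Slice one permutation so the two index sets cannot overlap.
idx = randperm(numObservations);
calibrationIdx = idx(1:numCalibration);
validationIdx  = idx(numCalibration + (1:numValidation));

overlap = intersect(calibrationIdx,validationIdx);
disp(isempty(overlap))
```

The resulting index vectors could then be passed to subset in place of the two separately permuted vectors.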
Create a dlquantizer object and specify the network to quantize. When you use the MATLAB execution environment, quantization is performed using the fi fixed-point data type, which requires a Fixed-Point Designer™ license.
quantObj = dlquantizer(net,ExecutionEnvironment="MATLAB");
Use the calibrate function to exercise the network with the calibration data and collect range statistics for the weights, biases, and activations at each layer.
calResults = calibrate(quantObj,calibrationDataStore)
calResults=16×5 table
Optimized Layer Name    Network Layer Name    Learnables / Activations    MinValue       MaxValue
____________________    __________________    ________________________    ___________    __________
{'conv_Weights'}        {'conv'      }        "Weights"                   -0.28447       0.36445
{'conv_Bias'   }        {'conv'      }        "Bias"                      -8.5358e-07    1.2699e-06
{'fc_1_Weights'}        {'fc_1'      }        "Weights"                   -0.084955      0.077845
{'fc_1_Bias'   }        {'fc_1'      }        "Bias"                      -0.014489      0.016811
{'fc_2_Weights'}        {'fc_2'      }        "Weights"                   -0.45607       0.40908
{'fc_2_Bias'   }        {'fc_2'      }        "Bias"                      -0.020831      0.020135
{'imageinput'  }        {'imageinput'}        "Activations"               0              1
{'features'    }        {'features'  }        "Activations"               -45            45
{'conv'        }        {'conv'      }        "Activations"               -1.8417        1.1134
{'batchnorm'   }        {'batchnorm' }        "Activations"               -9.5983        10.389
{'relu'        }        {'relu'      }        "Activations"               0              10.389
{'fc_1'        }        {'fc_1'      }        "Activations"               -13.472        14.063
{'flatten'     }        {'flatten'   }        "Activations"               -13.472        14.063
{'cat'         }        {'cat'       }        "Activations"               -45            45
{'fc_2'        }        {'fc_2'      }        "Activations"               -38.1          36.679
{'softmax'     }        {'softmax'   }        "Activations"               4.1264e-31     1
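To build intuition for why these ranges matter, the sketch below shows one generic way a collected range can determine a fixed-point scaling. It is not a description of what dlquantizer does internally: the signed 8-bit word length, the power-of-two scaling, and all variable names are assumptions made for illustration. The range is the fc_2 activation range from the table.

```matlab
wordLength = 8;          % assumption: signed 8-bit word
minValue = -38.1;        % fc_2 activation range from the calibration table
maxValue = 36.679;

% Bits needed to the left of the binary point to cover the largest magnitude.
maxAbs = max(abs([minValue maxValue]));
integerBits  = floor(log2(maxAbs)) + 1;
fractionBits = wordLength - 1 - integerBits;   % one bit is reserved for the sign
scale = 2^(-fractionBits);                     % value of one least significant bit

% Quantize a few sample values: scale, round, saturate, rescale.
x = [minValue 0.3 maxValue];
q = round(x/scale);
q = min(max(q,-2^(wordLength-1)),2^(wordLength-1)-1);
xQuantized = q*scale;

disp(fractionBits)
disp(max(abs(x - xQuantized)))
```

Under these assumptions, a wider range leaves fewer fraction bits and therefore a coarser resolution, which is why calibration data should be representative of the inputs the network is expected to see.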
Use the validate function to compare the results of the network before and after quantization using the validation data set. To validate the dlnetwork, define a dlquantizationOptions object and specify a custom metric function. The hComputeModelAccuracy metric function uses the classes from the training data to compare the predicted labels to the labels in the validation data.
dlquantOpts = dlquantizationOptions;
dlquantOpts.MetricFcn = {@(x)hComputeModelAccuracy(x,net,validationDataStore,classes)}
dlquantOpts =
dlquantizationOptions with properties:
Validation Metric Info
MetricFcn: {@(x)hComputeModelAccuracy(x,net,validationDataStore,classes)}
Validation Environment Info
Target: 'host'
Bitstream: ''
valResults = validate(quantObj,validationDataStore,dlquantOpts);
Examine the MetricResults.Result field of the validation output to view the accuracy of the quantized network and the floating-point network.
valResults.MetricResults.Result
ans=2×2 table
NetworkImplementation MetricOutput
_____________________ ____________
{'Floating-Point'} 0.99
{'Quantized' } 0.99
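To compare the two rows programmatically, the metric values can be pulled out of the table by name. The sketch below rebuilds a table with the same variable names and the values shown above, so that it runs on its own. In practice the table would come from valResults.MetricResults.Result, and the other variable names here are hypothetical.

```matlab
% Stand-in for valResults.MetricResults.Result, using the values shown above.
results = table({'Floating-Point';'Quantized'},[0.99;0.99], ...
    VariableNames=["NetworkImplementation","MetricOutput"]);

isQuantized   = strcmp(results.NetworkImplementation,'Quantized');
accuracyFloat = results.MetricOutput(~isQuantized);
accuracyQuant = results.MetricOutput(isQuantized);
accuracyDrop  = accuracyFloat - accuracyQuant;

fprintf("Accuracy change after quantization: %.4f\n",accuracyDrop)
```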
Supporting Functions
Train Network
The trainDigitsNetwork function takes as input a CombinedDatastore, the network classes, and the training options, and trains the network using the trainnet function.
function net = trainDigitsNetwork(dsTrain, classes, options)
    % Define network
    imageInputSize = [28 28 1];
    filterSize = 5;
    numFilters = 16;

    layers = [
        imageInputLayer(imageInputSize,Normalization="none")
        convolution2dLayer(filterSize,numFilters)
        batchNormalizationLayer
        reluLayer
        fullyConnectedLayer(50)
        flattenLayer
        concatenationLayer(1,2,Name="cat")
        fullyConnectedLayer(numel(classes))
        softmaxLayer];

    lgraph = layerGraph(layers);
    featInput = featureInputLayer(1,Name="features");
    lgraph = addLayers(lgraph,featInput);
    lgraph = connectLayers(lgraph,"features","cat/in2");

    dlnet = dlnetwork(lgraph);
    net = trainnet(dsTrain, dlnet,"crossentropy", options);
end
Mini-Batch Preprocessing Function
The preprocessMiniBatchTraining function preprocesses a mini-batch of predictors and labels for loss computation during training.
function [X1, X2, T] = preprocessMiniBatchTraining(X1Cell, X2Cell,TCell)
    % Concatenate.
    X1 = cat(4,X1Cell{1:end});
    X2 = cat(1, X2Cell{1:end});

    % Extract label data from cell and concatenate.
    T = cat(2,TCell{1:end});

    % One-hot encode labels.
    T = onehotencode(T,1);
end
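The array shapes this function produces can be checked on a toy mini-batch. The sketch below repeats the same three steps inline on two hypothetical observations, so that it runs without the function file. The class list of ten digits is an assumption that matches this example.

```matlab
classNames = string(0:9);   % assumption: ten digit classes

% Hypothetical mini-batch of two observations, supplied as cell arrays.
X1Cell = {rand(28,28); rand(28,28)};
X2Cell = {12; -30};
TCell  = {categorical("3",classNames); categorical("7",classNames)};

X1 = cat(4,X1Cell{1:end});   % images stacked along the fourth (batch) dimension
X2 = cat(1,X2Cell{1:end});   % features stacked along the first (batch) dimension
T  = cat(2,TCell{1:end});    % labels as a row vector
T  = onehotencode(T,1);      % classes along dimension 1, observations along dimension 2

disp(size(X1))
disp(size(X2))
disp(size(T))
```

These shapes correspond to the "SSCB" and "BC" formats passed to minibatchqueue: the image batch varies along dimension 4 and the feature batch along dimension 1.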
Evaluate Model Accuracy
The modelAccuracy function takes as input a dlnetwork object, a minibatchqueue of input data mbq, the network classes, and the number of observations, and returns the accuracy as a percentage.
function accuracy = modelAccuracy(net, mbq, classes, numObservations)
    % This function computes the model accuracy of a dlnetwork on the minibatchqueue 'mbq'.
    totalCorrect = 0;
    classes = int32(categorical(classes));
    reset(mbq);
    while hasdata(mbq)
        [dlX1, dlX2, Y] = next(mbq);
        dlYPred = extractdata(predict(net, dlX1, dlX2));
        YPred = onehotdecode(dlYPred,classes,1)';
        YReal = onehotdecode(Y,classes,1)';
        miniBatchCorrect = nnz(YPred == YReal);
        totalCorrect = totalCorrect + miniBatchCorrect;
    end
    accuracy = totalCorrect / numObservations * 100;
end
Model Predictions Function
The modelPredictions function takes as input a dlnetwork object, a minibatchqueue of input data mbq, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.
function YPred = modelPredictions(net, mbq, classes)
    YPred = [];
    reset(mbq);
    while hasdata(mbq)
        [dlX1, dlX2] = next(mbq);
        dlYPred = extractdata(predict(net, dlX1, dlX2));
        currentYPred = onehotdecode(dlYPred,classes,1)';
        YPred = cat(1, YPred, currentYPred);
    end
end
Metric Function for Validation
The hComputeModelAccuracy metric function accepts as input the prediction scores, a dlnetwork object, a validation datastore, and the network classes. The function compares predicted labels to ground truth label data and returns the accuracy.
function accuracy = hComputeModelAccuracy(predictionScores, ~, dataStore, classes)
    %% Computes model-level accuracy statistics

    % Load ground truth. The labels are in the third column of the datastore output.
    tmp = readall(dataStore);
    groundTruth = tmp(:,3);
    numGroundTruth = numel(groundTruth);
    predictionScores = reshape(predictionScores, [numel(predictionScores)/numGroundTruth numGroundTruth])';

    % Compare predicted label with actual ground truth.
    % Each entry is true when the prediction is correct.
    predictionError = {};
    for idx=1:numGroundTruth
        [~, idy] = max(predictionScores(idx,:));
        yActual = classes(idy);
        predictionError{end+1} = (yActual == groundTruth{idx}); %#ok
    end

    % Compute the fraction of correct predictions.
    predictionError = [predictionError{:}];
    accuracy = sum(predictionError)/numel(predictionError);
end
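The core of this metric is an argmax over the scores of each observation followed by a comparison with the true labels. The sketch below shows a vectorized form of that step on hypothetical scores. It assumes the scores are already arranged as one row per observation and one column per class, and it uses a three-class list for brevity.

```matlab
classes = {'0';'1';'2'};                 % hypothetical class list
scores = [0.7 0.2 0.1;                   % one row per observation,
          0.1 0.8 0.1;                   % one column per class
          0.3 0.3 0.4;
          0.6 0.3 0.1];
groundTruth = categorical({'0';'1';'1';'0'},classes);

% Index of the highest score in each row, mapped to a class label.
[~,idy] = max(scores,[],2);
predicted = categorical(classes(idy),classes);

accuracy = mean(predicted == groundTruth);
disp(accuracy)
```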