
Train a Semantic Segmentation Network

Locate the training images and their pixel labels in the triangleImages data set that ships with Computer Vision Toolbox.

dataSetDir = fullfile(toolboxdir('vision'),'visiondata','triangleImages');
imageDir = fullfile(dataSetDir,'trainingImages');
labelDir = fullfile(dataSetDir,'trainingLabels');

Create an image datastore for the images.

imds = imageDatastore(imageDir);

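Create a pixelLabelDatastore for the ground-truth pixel labels. The labelIDs vector maps the pixel values stored in the label images to classes: 255 is "triangle" and 0 is "background".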

classNames = ["triangle","background"];
labelIDs   = [255 0];
pxds = pixelLabelDatastore(labelDir,classNames,labelIDs);
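The pairing between labelIDs and classNames is positional: the first ID maps to the first class name, and so on. The sketch below mimics that mapping on a small, made-up label image using base MATLAB categorical. It illustrates the idea only and is not the internal implementation of pixelLabelDatastore; the variable names are hypothetical and chosen so they do not overwrite the ones above.

```matlab
% Hypothetical 3-by-3 label image: 255 marks triangle pixels, 0 marks background.
demoLabels = uint8([0 255 0; 255 255 255; 0 0 0]);

% Same class/ID pairing as above; IDs are cast to uint8 to match the image type.
demoClasses = {'triangle','background'};
demoIDs     = uint8([255 0]);

% Map each pixel value to the class name at the same position.
demoC = categorical(demoLabels, demoIDs, demoClasses);

% Count pixels per class; C(:) flattens so the counts cover the whole image.
counts = countcats(demoC(:))   % triangle: 4, background: 5
```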

Visualize a training image and its ground-truth pixel labels. Both are small, so they are enlarged by a factor of 5 for display.

I = read(imds);
C = read(pxds);

I = imresize(I,5);
L = imresize(uint8(C{1}),5);
imshowpair(I,L,'montage')
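For a label map, nearest-neighbor enlargement keeps every displayed value equal to an existing class index, so no in-between values appear at class boundaries. This is an alternative to the default interpolation used by the code above, not what that code does. The sketch below uses base MATLAB repelem on a made-up label map; smallLabels is a hypothetical stand-in for uint8(C{1}), whose values are the category indices 1 (triangle) and 2 (background).

```matlab
% Hypothetical 2-by-2 label map of category indices (1 = triangle, 2 = background).
smallLabels = uint8([1 2; 2 1]);

% Nearest-neighbor enlargement: replicate every pixel into a 5-by-5 block.
bigLabels = repelem(smallLabels, 5, 5);

size(bigLabels)      % 10-by-10
unique(bigLabels)'   % still only 1 and 2
```

With Image Processing Toolbox, imresize(uint8(C{1}),5,'nearest') should give the same kind of block-replicated result.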

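Create a simple semantic segmentation network based on a downsampling and upsampling design. A max pooling layer halves the spatial resolution, a transposed convolution layer restores it, and a final 1-by-1 convolution produces one score per class at every pixel.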

numFilters = 64;
filterSize = 3;
numClasses = 2;
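layers = [
    imageInputLayer([32 32 1])                                   % 32-by-32 grayscale input
    convolution2dLayer(filterSize,numFilters,'Padding',1)        % size-preserving 3-by-3 convolution
    reluLayer()
    maxPooling2dLayer(2,'Stride',2)                              % downsample to 16-by-16
    convolution2dLayer(filterSize,numFilters,'Padding',1)
    reluLayer()
    transposedConv2dLayer(4,numFilters,'Stride',2,'Cropping',1)  % upsample back to 32-by-32
    convolution2dLayer(1,numClasses)                             % per-pixel class scores
    softmaxLayer()
    pixelClassificationLayer()                                   % per-pixel cross-entropy loss
    ];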
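The design only works if the output has the same height and width as the input, because pixelClassificationLayer compares a prediction with a label at every pixel. The sketch below traces the spatial size through the layers using the standard output-size formulas, written as hypothetical helper functions convOut and tconvOut. It assumes square inputs, no dilation, and that 'Cropping',1 removes one row and column from each side of the transposed convolution output.

```matlab
% Output-size formulas for the layer types used above (square inputs, no dilation).
convOut  = @(n, f, p, s) floor((n + 2*p - f)/s) + 1;   % convolution or max pooling
tconvOut = @(n, f, c, s) (n - 1)*s + f - 2*c;          % transposed convolution

sz = 32;                      % imageInputLayer([32 32 1])
sz = convOut(sz, 3, 1, 1);    % 3x3 convolution, 'Padding',1          -> 32
sz = convOut(sz, 2, 0, 2);    % maxPooling2dLayer(2,'Stride',2)       -> 16
sz = convOut(sz, 3, 1, 1);    % 3x3 convolution, 'Padding',1          -> 16
sz = tconvOut(sz, 4, 1, 2);   % transposed conv, stride 2, cropping 1 -> 32
sz = convOut(sz, 1, 0, 1);    % 1x1 convolution to 2 class channels   -> 32
fprintf('Output: %d-by-%d-by-%d\n', sz, sz, 2);
```

If the transposed convolution used a different filter size or cropping, the output would no longer match 32-by-32, which is why those two parameters are chosen together with the stride.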

Set up the training options.

opts = trainingOptions('sgdm', ...
    'InitialLearnRate',1e-3, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',64);
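The 'sgdm' solver is stochastic gradient descent with momentum: each update steps against the gradient, scaled by the learning rate, and adds a fraction of the previous step. The sketch below applies that update rule to a made-up one-parameter quadratic so the role of 'InitialLearnRate' is visible. The toy objective is hypothetical, and the momentum value 0.9 is an assumption (the trainingOptions default when 'Momentum' is not set, as in the code above).

```matlab
% Toy objective E(theta) = 0.5*theta^2, whose gradient is theta (hypothetical).
gradE = @(theta) theta;

alpha = 1e-3;      % learning rate, matching 'InitialLearnRate' above
gamma = 0.9;       % assumed momentum value

theta     = 1;     % current parameter value
thetaPrev = 1;     % previous value, so the first step has no momentum term
for k = 1:300
    step      = -alpha*gradE(theta) + gamma*(theta - thetaPrev);
    thetaPrev = theta;
    theta     = theta + step;
end
theta   % has decreased from 1 toward the minimum at 0
```

With this small learning rate, the momentum term accumulates past steps and makes progress several times faster than plain gradient descent with the same alpha would.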

Combine the image and pixel label datastores for training.

trainingData = combine(imds,pxds);
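Reading from a combined datastore returns the outputs of the underlying datastores side by side, so each read yields an image together with its pixel labels. The sketch below shows that pairing with small in-memory arrays wrapped in arrayDatastore, which stands in for the image and pixel label datastores above. The arrays are made up, and the sketch assumes a MATLAB release that provides arrayDatastore (R2020b or later).

```matlab
% Three hypothetical 4-by-4 "images" and matching label maps, stacked along dim 3.
imgs   = rand(4, 4, 3);
labels = uint8(randi([1 2], 4, 4, 3));

% One datastore per array; each read returns one 4-by-4 slice in a 1-by-1 cell.
dsImg = arrayDatastore(imgs,   'IterationDimension', 3);
dsLbl = arrayDatastore(labels, 'IterationDimension', 3);

% combine concatenates the reads horizontally: each read is {image, labels}.
cds  = combine(dsImg, dsLbl);
pair = read(cds)
```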

Train the network.

net = trainNetwork(trainingData,layers,opts);
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:00 |       58.11% |       1.3458 |          0.0010 |
|      17 |          50 |       00:00:12 |       97.30% |       0.0924 |          0.0010 |
|      34 |         100 |       00:00:24 |       98.09% |       0.0575 |          0.0010 |
|      50 |         150 |       00:00:37 |       98.56% |       0.0424 |          0.0010 |
|      67 |         200 |       00:00:49 |       98.48% |       0.0435 |          0.0010 |
|      84 |         250 |       00:01:02 |       98.66% |       0.0363 |          0.0010 |
|     100 |         300 |       00:01:14 |       98.90% |       0.0310 |          0.0010 |
|========================================================================================|
Training finished: Reached final iteration.
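The Mini-batch Accuracy and Mini-batch Loss columns summarize per-pixel predictions. The sketch below computes the same kind of quantities for a made-up 2-by-2 score map: a softmax over the class dimension, the cross-entropy of the true class averaged over pixels, and the fraction of pixels whose highest-scoring class matches the ground truth. The scores and labels are hypothetical, and the plain averaging is an assumption for illustration; pixelClassificationLayer may normalize or weight its loss differently.

```matlab
% Hypothetical raw scores for a 2-by-2 image and 2 classes (H-by-W-by-K).
scores = cat(3, [2 0; 1 -1], [0 1; 0 2]);
% Hypothetical ground truth as class indices (1 = triangle, 2 = background).
truth  = [1 2; 2 2];

% Softmax across the class dimension (shifted by the max for numerical stability).
e     = exp(scores - max(scores, [], 3));
probs = e ./ sum(e, 3);

% Probability assigned to the true class at each pixel.
[H, W, K] = size(probs);
onehot = false(H, W, K);
for k = 1:K
    onehot(:, :, k) = (truth == k);
end
pTrue = sum(probs .* onehot, 3);

loss = -mean(log(pTrue(:)));          % mean per-pixel cross-entropy

[~, predicted] = max(probs, [], 3);   % per-pixel predicted class
accuracy = mean(predicted(:) == truth(:))   % 0.75: pixel (2,1) is misclassified
```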

Read and display a test image.

testImage = imread('triangleTest.jpg');
imshow(testImage)

Segment the test image and display the results.

C = semanticseg(testImage,net);
B = labeloverlay(testImage,C);
imshow(B)
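An overlay of this kind blends a class color into the image at the pixels assigned to that class. The sketch below shows the blending idea in base MATLAB for a made-up grayscale image, tinting only the "triangle" pixels. It is not labeloverlay's implementation: the color, the 0.5 blending weight, and the choice to tint a single class are assumptions made to keep the example short, whereas labeloverlay may color every class.

```matlab
% Hypothetical 3-by-3 grayscale image (values in [0,1]) and predicted class indices.
img  = [0.2 0.4 0.6; 0.2 0.4 0.6; 0.2 0.4 0.6];
pred = [1 1 2; 1 2 2; 2 2 2];    % 1 = triangle, 2 = background

alpha = 0.5;                     % weight of the overlay color (assumed)
color = [1 0 0];                 % hypothetical red tint for "triangle" pixels

% Start from an RGB copy of the grayscale image.
rgb  = repmat(img, 1, 1, 3);
mask = (pred == 1);

% Blend the color into the masked pixels, one channel at a time.
for c = 1:3
    channel       = rgb(:, :, c);
    channel(mask) = (1 - alpha)*channel(mask) + alpha*color(c);
    rgb(:, :, c)  = channel;
end
```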