
Object Detection Using YOLO v4 Deep Learning

This example shows how to detect objects in images using a you only look once version 4 (YOLO v4) deep learning network. In this example, you will:

  • Configure a dataset for training, validation, and testing of a YOLO v4 object detection network. You will also perform data augmentation on the training dataset to improve the training accuracy.

  • Compute anchor boxes from the training data to use for training the YOLO v4 object detection network.

  • Create a YOLO v4 object detector by using the yolov4ObjectDetector function and train the detector by using the trainYOLOv4ObjectDetector function.

This example also provides a pretrained YOLO v4 object detector to use for detecting vehicles in an image. The pretrained network uses tiny-yolov4-coco as the backbone network and is trained on a vehicle dataset. For more information about the YOLO v4 object detection network, see Getting Started with YOLO v4 (Computer Vision Toolbox).

Load Dataset

This example uses a small vehicle dataset that contains 295 images. Many of these images come from the Caltech Cars 1999 and 2001 datasets, available at the Caltech Computational Vision website created by Pietro Perona and used with permission. Each image contains one or two labeled instances of a vehicle. A small dataset is useful for exploring the YOLO v4 training procedure, but in practice, more labeled images are needed to train a robust detector.

Unzip the vehicle images and load the vehicle ground truth data.

unzip vehicleDatasetImages.zip
data = load("vehicleDatasetGroundTruth.mat");
vehicleDataset = data.vehicleDataset;

The vehicle data is stored in a two-column table. The first column contains the image file paths and the second column contains the bounding boxes.

Display the first few rows of the dataset.

vehicleDataset(1:4,:)
ans=4×2 table
              imageFilename                   vehicle     
    _________________________________    _________________

    {'vehicleImages/image_00001.jpg'}    {[220 136 35 28]}
    {'vehicleImages/image_00002.jpg'}    {[175 126 61 45]}
    {'vehicleImages/image_00003.jpg'}    {[108 120 45 33]}
    {'vehicleImages/image_00004.jpg'}    {[124 112 38 36]}

Add the full path to the local vehicle data folder.

vehicleDataset.imageFilename = fullfile(pwd,vehicleDataset.imageFilename);

Split the dataset into training, validation, and test sets. Select 60% of the data for training, 10% for validation, and the rest for testing the trained detector.

rng("default");
shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * length(shuffledIndices) );

trainingIdx = 1:idx;
trainingDataTbl = vehicleDataset(shuffledIndices(trainingIdx),:);

validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = vehicleDataset(shuffledIndices(validationIdx),:);

testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = vehicleDataset(shuffledIndices(testIdx),:);
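
To confirm the resulting split sizes, you can display the number of images in each set. This is a quick optional check, not part of the original example.

% Optional check: display the number of images in each split.
fprintf("Training: %d, Validation: %d, Test: %d\n", ...
    height(trainingDataTbl),height(validationDataTbl),height(testDataTbl));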

Use imageDatastore and boxLabelDatastore to create datastores for loading the image and label data during training and evaluation.

imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,"vehicle"));

imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"});
bldsValidation = boxLabelDatastore(validationDataTbl(:,"vehicle"));

imdsTest = imageDatastore(testDataTbl{:,"imageFilename"});
bldsTest = boxLabelDatastore(testDataTbl(:,"vehicle"));

Combine image and box label datastores.

trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);

Use validateInputData to detect invalid images, bounding boxes, or labels when the dataset contains one or more of the following:

  • Samples with invalid image format or NaN values

  • Bounding boxes with zero, NaN, or Inf values, or empty bounding boxes

  • Missing or non-categorical labels

The bounding box values must be finite positive integers and must not be NaN. The height and width of each bounding box must be positive, and the box must lie within the image boundary.

validateInputData(trainingData);
validateInputData(validationData);
validateInputData(testData);
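
As a quick illustration of these rules, you can check a single bounding box manually. This minimal sketch uses the first box from the dataset table; a complete check also verifies that the box lies within the image boundary.

% Optional check: a valid box has finite, positive x, y, width, and height values.
bboxCheck = [220 136 35 28];
isValid = all(isfinite(bboxCheck)) && all(bboxCheck > 0)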

Display one of the training images and box labels.

data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,"Rectangle",bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)

reset(trainingData);

Create a YOLO v4 Object Detector Network

Specify the network input size to be used for training.

inputSize = [416 416 3];

Specify the name of the object class to detect.

className = "vehicle";

Use the estimateAnchorBoxes (Computer Vision Toolbox) function to estimate anchor boxes based on the size of objects in the training data. To account for the resizing of the images prior to training, resize the training data before estimating the anchor boxes. Use the transform function with the preprocessData helper function to resize the training data to the network input size, then define the number of anchor boxes and estimate them.

rng("default")
trainingDataForEstimation = transform(trainingData,@(data)preprocessData(data,inputSize));
numAnchors = 6;
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);
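
The meanIoU output measures how well the estimated anchor boxes overlap the boxes in the training data. You can optionally inspect it before training; as a rough guideline, a mean IoU above 0.5 indicates a reasonable fit.

% Display the mean IoU between the estimated anchors and the training boxes.
meanIoU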

Specify the anchorBoxes argument as the anchor boxes to use in all the detection heads. anchorBoxes is an M-by-1 cell array, where M denotes the number of detection heads. Each cell contains an N-by-2 matrix, where N is the number of anchors to use for that head. Select the anchorBoxes for each detection head based on the feature map size: use larger anchors at lower scale and smaller anchors at higher scale. To do so, sort the anchors by area in descending order, and assign the first three to the first detection head and the last three to the second detection head.

area = anchors(:, 1).*anchors(:,2);
[~,idx] = sort(area,"descend");

anchors = anchors(idx,:);
anchorBoxes = {anchors(1:3,:)
    anchors(4:6,:)};
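
You can optionally display the contents of anchorBoxes to confirm the assignment.

anchorBoxes{1} % three largest anchors, assigned to the first detection head
anchorBoxes{2} % three smallest anchors, assigned to the second detection head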

For more information on choosing anchor boxes, see Estimate Anchor Boxes From Training Data (Computer Vision Toolbox) and Anchor Boxes for Object Detection (Computer Vision Toolbox).

Create the YOLO v4 object detector by using the yolov4ObjectDetector function. Specify the name of a pretrained YOLO v4 detection network trained on the COCO dataset. Specify the class name and the estimated anchor boxes.

detector = yolov4ObjectDetector("tiny-yolov4-coco",className,anchorBoxes,InputSize=inputSize);
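
You can optionally display the detector object to verify the class names, anchor boxes, and input size before training.

% Inspect the detector configuration.
disp(detector)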

Perform Data Augmentation

Perform data augmentation to improve training accuracy. Use the transform function to apply custom data augmentations to the training data. The augmentData helper function applies the following augmentations to the input data:

  • Color jitter augmentation in HSV space

  • Random horizontal flip

  • Random scaling by 10 percent

Note that data augmentation is not applied to the test and validation data. Ideally, test and validation data are representative of the original data and are left unmodified for unbiased evaluation.

augmentedTrainingData = transform(trainingData,@augmentData);

Read and display samples of augmented training data.

augmentedData = cell(4,1);
for k = 1:4
    data = read(augmentedTrainingData);
    augmentedData{k} = insertShape(data{1},"rectangle",data{2});
    reset(augmentedTrainingData);
end
figure
montage(augmentedData,BorderSize=10)

Specify Training Options

Use trainingOptions to specify network training options. Train the object detector using the Adam solver for 80 epochs with a constant learning rate of 0.001. To obtain the trained detector with the lowest validation loss, set OutputNetwork to "best-validation-loss". Set ValidationData to the validation data and ValidationFrequency to 1000. To validate the data more often, reduce the ValidationFrequency value, which also increases the training time. Use ExecutionEnvironment to determine what hardware resources are used to train the network. The default value is "auto", which selects a GPU if one is available and otherwise selects the CPU. Set CheckpointPath to a temporary location to enable the saving of partially trained detectors during the training process. If training is interrupted, for example by a power outage or system failure, you can resume training from the saved checkpoint.

options = trainingOptions("adam", ...
    GradientDecayFactor=0.9, ...
    SquaredGradientDecayFactor=0.999, ...
    InitialLearnRate=0.001, ...
    LearnRateSchedule="none", ...
    MiniBatchSize=4, ...
    L2Regularization=0.0005, ...
    MaxEpochs=80, ...
    DispatchInBackground=true, ...
    ResetInputNormalization=true, ...
    Shuffle="every-epoch", ...
    VerboseFrequency=20, ...
    ValidationFrequency=1000, ...
    CheckpointPath=tempdir, ...
    ValidationData=validationData, ...
    OutputNetwork="best-validation-loss");

Train YOLO v4 Object Detector

Use the trainYOLOv4ObjectDetector function to train the YOLO v4 object detector. This example was run on an NVIDIA™ RTX A5000 GPU with 24 GB of memory, and training took approximately 33 minutes with this setup. The training time varies depending on the hardware you use. Instead of training the network, you can also use a pretrained YOLO v4 object detector from the Computer Vision Toolbox™.

Download the pretrained detector by using the downloadPretrainedYOLOv4Detector helper function. To train the detector on the augmented training data, set the doTraining value to true.

doTraining = false;
if doTraining       
    % Train the YOLO v4 detector.
    [detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options);
else
    % Load pretrained detector for the example.
    detector = downloadPretrainedYOLOv4Detector();
end
Downloading pretrained detector...

Run the detector on a test image.

I = imread("highway.png");
[bboxes,scores,labels] = detect(detector,I);
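
The detect function returns detections with scores above a default threshold of 0.5. To return fewer, more confident detections, you can optionally raise the threshold; the value 0.6 here is only an illustration.

% Optional: raise the detection threshold to suppress low-confidence detections.
[bboxesStrong,scoresStrong,labelsStrong] = detect(detector,I,Threshold=0.6);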

Display the results.

I = insertObjectAnnotation(I,"rectangle",bboxes,scores);
figure
imshow(I)

Evaluate Detector Using Test Set

Evaluate the trained object detector on a large set of images to measure the performance. Computer Vision Toolbox™ provides an object detector evaluation function (evaluateObjectDetection (Computer Vision Toolbox)) to measure common metrics such as average precision and log-average miss rate. For this example, use the average precision metric to evaluate performance. The average precision provides a single number that incorporates the ability of the detector to make correct classifications (precision) and the ability of the detector to find all relevant objects (recall).

Run the detector on all the test images. Set the detection threshold to a low value to detect as many objects as possible. This helps you evaluate the detector precision across the full range of recall values.

detectionResults = detect(detector,testData,Threshold=0.01);

Evaluate the object detector using average precision metric.

metrics = evaluateObjectDetection(detectionResults,testData);
classID = 1;
precision = metrics.ClassMetrics.Precision{classID};
recall = metrics.ClassMetrics.Recall{classID};

The precision-recall (PR) curve highlights how precise a detector is at varying levels of recall. The ideal precision is 1 at all recall levels. The use of more data can help improve the average precision but might require more training time. Plot the PR curve.

figure
plot(recall,precision)
xlabel("Recall")
ylabel("Precision")
grid on
title(sprintf("Average Precision = %.2f",metrics.ClassMetrics.mAP(classID)))

Supporting Functions

Helper function for performing data augmentation.

function data = augmentData(A)
% Apply random horizontal flipping, and random X/Y scaling. Boxes that get
% scaled outside the bounds are clipped if the overlap is above 0.25. Also,
% jitter image color.

data = cell(size(A));
for ii = 1:size(A,1)
    I = A{ii,1};
    bboxes = A{ii,2};
    labels = A{ii,3};
    sz = size(I);

    if numel(sz) == 3 && sz(3) == 3
        I = jitterColorHSV(I,...
            Contrast=0.0,...
            Hue=0.1,...
            Saturation=0.2,...
            Brightness=0.2);
    end
    
    % Randomly flip image.
    tform = randomAffine2d(XReflection=true,Scale=[1 1.1]);
    rout = affineOutputView(sz,tform,BoundsStyle="centerOutput");
    I = imwarp(I,tform,OutputView=rout);
    
    % Apply same transform to boxes.
    [bboxes,indices] = bboxwarp(bboxes,tform,rout,OverlapThreshold=0.25);
    labels = labels(indices);
    
    % Return original data only when all boxes are removed by warping.
    if isempty(indices)
        data(ii,:) = A(ii,:);
    else
        data(ii,:) = {I,bboxes,labels};
    end
end
end

function data = preprocessData(data,targetSize)
% Resize the images and scale the pixels to between 0 and 1. Also scale the
% corresponding bounding boxes.

for ii = 1:size(data,1)
    I = data{ii,1};
    imgSize = size(I);
    
    bboxes = data{ii,2};

    I = im2single(imresize(I,targetSize(1:2)));
    scale = targetSize(1:2)./imgSize(1:2);
    bboxes = bboxresize(bboxes,scale);
    
    data(ii,1:2) = {I,bboxes};
end
end

Helper function for downloading the pretrained YOLO v4 object detector.

function detector = downloadPretrainedYOLOv4Detector()
% Download a pretrained yolov4 detector.
if ~exist("yolov4TinyVehicleExample_24a.mat", "file")
    if ~exist("yolov4TinyVehicleExample_24a.zip", "file")
        disp("Downloading pretrained detector...");
        pretrainedURL = "https://ssd.mathworks.com/supportfiles/vision/data/yolov4TinyVehicleExample_24a.zip";
        websave("yolov4TinyVehicleExample_24a.zip", pretrainedURL);
    end
    unzip("yolov4TinyVehicleExample_24a.zip");
end
pretrained = load("yolov4TinyVehicleExample_24a.mat");
detector = pretrained.detector;
end
