Data augmentation in CNN
8 visualizaciones (últimos 30 días)
Mostrar comentarios más antiguos
Srinidhi Gorityala
el 21 de Mayo de 2020
Respondida: Srivardhan Gadila
el 29 de Mayo de 2020
Heloo... Iam working on a data set of 300 images of 2 classes. I want to perform the data augmentation. Below is the code i have attached for data augmentation. According to the methodology the accuracy must increase, but in our case it is decreasing from 96% to 87%. could anyone please help me in solving thia problem.
Thank you in advance:)
clc;
clear all;
close all;
myTrainingFolder = 'C:\Users\Admin\Desktop\Major Project\cnn_dataset';
%testingFolder = 'C:\Users\Be Happy\Documents\MATLAB\gtsrbtest';
imds = imageDatastore(myTrainingFolder,'IncludeSubfolders', true, 'LabelSource', 'foldernames');
%testingSet = imageDatastore(testingFolder,'IncludeSubfolders', true, 'LabelSource', 'foldernames');
labelCount = countEachLabel(imds);
numClasses = height(labelCount);
numImagesTraining = numel(imds.Files);
%% Create training and validation sets
[imdsTrainingSet, imdsValidationSet] = splitEachLabel(imds, 0.7, 'randomize');
%% Build a simple CNN
imageSize = [227 227 3];
% Specify the convolutional neural network architecture.
layers = [
imageInputLayer(imageSize)
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];
%% Specify training options
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsValidationSet, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
%% Train the network
net1 = trainNetwork(imdsTrainingSet,layers,options);
analyzeNetwork(net1);
%% Report accuracy of baseline classifier on validation set
YPred = classify(net1,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
imdsAccuracy = sum(YPred == YValidation)/numel(YValidation);
%% Plot confusion matrix
figure, plotconfusion(YValidation,YPred)
%% PART 2: Baseline Classifier with Data Augmentation
%% Create augmented image data store
% Specify data augmentation options and values/ranges
imageAugmenter = imageDataAugmenter( ...
'RandRotation',[-20,20], ...
'RandXTranslation',[-5 5], ...
'RandYTranslation',[-5 5]);
% Apply transformations (using randomly picked values) and build augmented
% data store
augImds = augmentedImageDatastore(imageSize,imdsTrainingSet, ...
'DataAugmentation',imageAugmenter);
% (OPTIONAL) Preview augmentation results
batchedData = preview(augImds);
figure, imshow(imtile(batchedData.input))
%% Train the network.
net2 = trainNetwork(augImds,layers,options);
%% Report accuracy of baseline classifier with image data augmentation
YPred = classify(net2,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
augImdsAccuracy = sum(YPred == YValidation)/numel(YValidation);
%% Plot confusion matrix
figure, plotconfusion(YValidation,YPred)
2 comentarios
Respuesta aceptada
Srivardhan Gadila
el 29 de Mayo de 2020
The following are few suggestions:
- Make sure that all the classes have equal number of observations.
- Check how trainNetwork uses an augmented image datastore to transform training data for each epoch: Augment Images for Training with Random Geometric Transformations. Then try training the network with augmentedImageDatastore for more epochs.
- Try changing the network architecture itself if there is no improvement in the accuracy when augmentedImageDatastore is used. You can refer to Choose Network Architecture.
- Try Using dropout layers & increasing global L2 regularization factor in new architecture. For more information, see dropoutLayer & 'L2Regularization' option in trainingOptions.
0 comentarios
Más respuestas (0)
Ver también
Categorías
Más información sobre Recognition, Object Detection, and Semantic Segmentation en Help Center y File Exchange.
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!