How to change output classes for image classification in alexnet?

13 visualizaciones (últimos 30 días)
Hello everyone
I'm trying to implement the following example:
but I got this error:
Error using trainNetwork (line 170)
Invalid training data. The output size (28224) of the last layer does not match the number of classes (24).
and this is my code:
tic
clear;
clc;
outputFolder = fullfile('Spilt_Dataset');
rootFolder = fullfile(outputFolder , 'Categories');
categories = {'front_or_left' , 'front_or_right','hump' , 'left_turn','narrow_from_left' , 'narrows_from_right',...
'no_horron' , 'no_parking','no_u_turn' , 'overtaking_is_forbidden','parking' , 'pedestrian_crossing',...
'right_or_left' , 'right_turn','rotor' , 'slow','speed_30' , 'speed_40',...
'speed_50','speed_60','speed_80','speed_100','stop','u_turn'};
imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','Foldernames');
tbl = countEachLabel(imds);
minSetCount = min(tbl{:,2});
imds = splitEachLabel(imds, minSetCount, 'randomize');
countEachLabel(imds);
front_or_left = find(imds.Labels == 'front_or_left',1);
front_or_right = find(imds.Labels == 'front_or_right',1);
hump = find(imds.Labels == 'hump',1);
left_turn = find(imds.Labels == 'left_turn',1);
narrow_from_left = find(imds.Labels == 'narrow_from_left',1);
narrows_from_right = find(imds.Labels == 'narrows_from_right',1);
no_horron = find(imds.Labels == 'no_horron',1);
no_parking = find(imds.Labels == 'no_parking',1);
no_u_turn = find(imds.Labels == 'no_u_turn',1);
overtaking_is_forbidden = find(imds.Labels == 'overtaking_is_forbidden',1);
parking = find(imds.Labels == 'parking',1);
pedestrian_crossing = find(imds.Labels == 'pedestrian_crossing',1);
right_or_left = find(imds.Labels == 'right_or_left',1);
right_turn = find(imds.Labels == 'right_turn',1);
rotor = find(imds.Labels == 'rotor',1);
slow = find(imds.Labels == 'slow',1);
speed_30 = find(imds.Labels == 'speed_30',1);
speed_40 = find(imds.Labels == 'speed_40',1);
speed_50 = find(imds.Labels == 'speed_50',1);
speed_60 = find(imds.Labels == 'speed_60',1);
speed_80 = find(imds.Labels == 'speed_80',1);
speed_100 = find(imds.Labels == 'speed_100',1);
stop = find(imds.Labels == 'stop',1);
u_turn = find(imds.Labels == 'u_turn',1);
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
net = googlenet;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
numClasses = numel(categories(imdsTrain.Labels))
newLearnableLayer = fullyConnectedLayer(numClasses, ...
'Name','new_fc', ...
'WeightLearnRateFactor',10, ...
'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'loss3-classifier',newLearnableLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
'RandXReflection',true, ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
netTransfer = trainNetwork(augimdsTrain,lgraph,options);

Respuesta aceptada

Srivardhan Gadila
Srivardhan Gadila el 16 de Abr. de 2021
If the classes listed from line "front_or_left = find(imd...." in the above code are the only classes which are 24 in total then the issue might be that the value returned by
numClasses = numel(categories(imdsTrain.Labels))
is probably 28224 and not 24. If that's the case then specify the value 24 directly in the line
newLearnableLayer = fullyConnectedLayer(24, ...
and try running the code.
You can try debugging the issue by checking the layer activation size by using analyzeNetwork function.
If the network is fine then the issue might be w.r.t your data i.e., labels itself.
If you still face the issue then Contact MathWorks Support.

Más respuestas (0)

Categorías

Más información sobre Image Data Workflows en Help Center y File Exchange.

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by