- You are using the 'Weights','none' option when loading the pre-trained network. This means that the weights of the network will be randomly initialized instead of using the pre-trained weights. If you want to use the pre-trained weights, you should remove this option.
- In the line lys = net;, you are assigning the net to lys, but then you are not using lys anywhere else in the code. You can remove this line.
- The line numClasses = numel(categories(imdsTrain.Labels)); is unnecessary because you have already defined the number of classes explicitly as 3 in the line newFCLayer = fullyConnectedLayer(3,'Name','new_fc',...). You can remove the numClasses line.
- In the line lgraph = replaceLayer(lgraph,'output',newClassLayer);, it seems like you intended to replace the last classification layer with newClassLayer. However, in AlexNet, the last classification layer is fc8, not output. You should replace 'output' with 'fc8' in that line.
- In the testing process, when loading the image with I = imread('plastic73.jpg');, make sure that the image file plastic73.jpg is in the current working directory or provide the full path to the image file.
Image classification based on transfer learning using Alexnet
12 visualizaciones (últimos 30 días)
Mostrar comentarios más antiguos
Below is my full coding. Did the coding to perform transfer learning using alexnet correct?
clc
clear all;
close all;
outputFolder=fullfile('recycle101');
rootFolder=fullfile(outputFolder,'project');
categories={'aluminiumcan','petbottle','drinkcartonbox'};
imds=imageDatastore(fullfile(rootFolder,categories),'LabelSource','foldernames');
tbl=countEachLabel(imds)
minSetCount=min(tbl{:,2});
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
countEachLabel(imds);
%randomly choose file for aluminium can, PET bottles, and drink carton box
AluminiumCan=find(imds.Labels=='aluminiumcan',1);
PETBottles=find(imds.Labels=='petbottle',1);
DrinkCartonBox=find(imds.Labels=='drinkcartonbox',1);
%plot image that was pick randomly
figure
subplot(2,2,1);
imshow(readimage(imds,AluminiumCan));
subplot(2,2,2);
imshow(readimage(imds,PETBottles));
subplot(2,2,3);
imshow(readimage(imds,DrinkCartonBox));
%Load pre-trained network
%net=alexnet;
net =alexnet('Weights','none');
analyzeNetwork(net)
lys = net;
lys(end-3:end)
numClasses = numel(categories(imdsTrain.Labels));
lgraph = layerGraph(lys);
%Replace the classification layers for new task
newFCLayer = fullyConnectedLayer(3,'Name','new_fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'fc8',newFCLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);
%Resize the image
imageSize=lys(1).InputSize;
augmentedTrainingSet=augmentedImageDatastore(imageSize,...
imdsTrain,'ColorPreprocessing','gray2rgb');
augmentedValidateSet=augmentedImageDatastore(imageSize,...
imdsValidation,'ColorPreprocessing','gray2rgb');
options = trainingOptions('sgdm', ...
'MiniBatchSize',4, ...
'MaxEpochs',8, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augmentedValidateSet, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'ExecutionEnvironment','cpu', ...
'Plots','training-progress');
trainedNet = trainNetwork(augmentedTrainingSet,lgraph,options);
[YPred,probs] = classify(trainedNet,augmentedValidateSet,'ExecutionEnvironment', 'cpu');
YValidation = imdsValidation.Labels;
%Calculate accuracy,error,pecision and recall
accuracy = sum(YPred == YValidation)/numel(YValidation)
error=1-accuracy
confMat=confusionmat(YValidation ,YPred);
confMat=bsxfun(@rdivide,confMat,sum(confMat,2));
z=mean(diag(confMat))
cmt=confMat
sum_of_row= sum(cmt,2)
z=(diag(cmt));
precision = z./sum_of_row ;
overallprecision=mean(precision);
sum_of_column= sum(cmt,1)
recall=z./sum_of_column'
totalrecall=mean(recall);
f1=2*(overallprecision*totalrecall)/(overallprecision+totalrecall)
categories = {'aluminium can','PET bottle','drink carton box'}
label = categorical(categories)
cm = confusionchart(cmt,label)
cm.RowSummary = 'row-normalized';
cm.ColumnSummary = 'column-normalized';
%%save Network
save simpleDL.mat trainedNet lgraph
%% Testing process
I = imread('plastic73.jpg');
ds=augmentedImageDatastore(imageSize,...
I,'ColorPreprocessing','gray2rgb');
predictedLabel = trainedNet.classify(ds);
sprintf('The loaded image belongs to %s class', predictedLabel)
0 comentarios
Respuestas (1)
Shaik
el 14 de Mayo de 2023
Hi,
The code you provided is mostly correct for performing transfer learning using the AlexNet architecture. However, there are a few points to note:
Apart from these points, the overall structure of the code for transfer learning using AlexNet looks fine.
Ver también
Categorías
Más información sobre Image Data Workflows en Help Center y File Exchange.
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!