How to add new classes to a neural network?
I made a network for flower recognition. It's pretty much a copy of AlexNet, but with some layers removed. I trained it with 5 classes, but now I want to add more. How can I do that without retraining it from scratch?
allImages = imageDatastore('D:\stuff machine learning\flowers', ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[trainingImages, testImages] = splitEachLabel(allImages, 0.8, 'randomize');
conv1 = convolution2dLayer(11,96,'Stride',4,'Padding',0);   % 290.5k neurons
conv2 = convolution2dLayer(5,256,'Stride',1,'Padding',2);   % 7 million neurons
conv3 = convolution2dLayer(3,384,'Stride',1,'Padding',1);
conv4 = convolution2dLayer(3,384,'Stride',1,'Padding',1);
conv5 = convolution2dLayer(3,256,'Stride',1,'Padding',1);
layers = [ ...
    imageInputLayer([227 227 3]);
    conv1;
    reluLayer('Name','relu1');
    maxPooling2dLayer(3,'Name','pool1','Stride',2);
    conv2;
    reluLayer('Name','relu2');
    maxPooling2dLayer(3,'Name','pool2','Stride',2);
    conv3;
    reluLayer('Name','relu3');
    conv4;
    reluLayer('Name','relu4');
    conv5;
    reluLayer('Name','relu5');
    maxPooling2dLayer(3,'Name','pool5','Stride',2);
    fullyConnectedLayer(4096,'Name','fc6');
    reluLayer('Name','relu6');
    dropoutLayer('Name','drop6');
    fullyConnectedLayer(4096,'Name','fc7');
    reluLayer('Name','relu7');
    dropoutLayer('Name','drop7');
    fullyConnectedLayer(5,'Name','fc8');      % one output per class
    softmaxLayer('Name','prob');
    classificationLayer('Name','output');]
opts = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 10, ...
    'L2Regularization', 0.008, ...
    'MaxEpochs', 30, ...
    'MiniBatchSize', 40, ...
    'ValidationData', testImages, ...
    'Verbose', true, ...
    'Plot', 'training-progress');
testImages.ReadFcn = @readFunctionTrain1;       % user-defined read function (not shown)
trainingImages.ReadFcn = @readFunctionTrain1;
% train the network
myNet = trainNetwork(trainingImages, layers, opts);
[YPred,probs] = classify(myNet,testImages);
accuracy = mean(YPred == testImages.Labels)

idx = randperm(numel(testImages.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(testImages,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end
This is the network.
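Because OutputSize of a trained layer cannot be changed, one workaround is to build a new, larger fc8 from the old one and then fine-tune. The sketch below shows only the weight bookkeeping, on hypothetical toy values (the class names and numbers are invented; oldW, oldB, oldClasses and newClasses are illustrative names): each existing class keeps its trained row of fc8 weights and its bias, moved to the position that class takes in the new sorted class list, and each new class starts from small random weights and a zero bias.

```matlab
% Hypothetical toy stand-ins for the trained classifier. With the real
% network they would come from, for example:
%   oldW = myNet.Layers(end-2).Weights;             % fc8 weights, 5 x 4096
%   oldB = myNet.Layers(end-2).Bias;                % fc8 bias,    5 x 1
%   oldClasses = cellstr(myNet.Layers(end).Classes);
%   newClasses = categories(allImages.Labels);      % old + new folders
oldClasses = {'daisy'; 'rose'; 'tulip'};            % trained classes
newClasses = {'daisy'; 'iris'; 'rose'; 'tulip'};    % 'iris' is new
oldW = single(reshape(1:12, 3, 4));                 % 3 classes x 4 inputs
oldB = single([0.1; 0.2; 0.3]);

numIn = size(oldW, 2);
% Start every row with small random weights and a zero bias ...
W = cast(0.01 * randn(numel(newClasses), numIn), class(oldW));
b = zeros(numel(newClasses), 1, class(oldB));
% ... then copy each trained row to its class's position in the new order
[isOld, loc] = ismember(newClasses, oldClasses);
W(isOld, :) = oldW(loc(isOld), :);
b(isOld)    = oldB(loc(isOld));
```

With the real arrays, the new layer could be assembled roughly as `fc8 = fullyConnectedLayer(numel(newClasses), 'Name', 'fc8'); fc8.Weights = W; fc8.Bias = b;`, followed by `layers = myNet.Layers; layers(end-2) = fc8; layers(end) = classificationLayer('Name', 'output');` and a call to `trainNetwork` on images of all classes, typically with a smaller InitialLearnRate and fewer epochs. Layers whose Weights are already filled in start from those values, so this is fine-tuning rather than training from scratch. Images of the old classes are still needed, otherwise the network tends to forget them.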
1 Comment
Balakrishnan Rajan
on 16 Oct 2018
I am trying to do the same thing. In theory, this can be done by enlarging the Weights matrix, the Bias vector and the OutputSize of the fully connected layer, increasing the OutputSize of the classification output layer, and adding the new category label to its Classes property. However, some of these properties, such as OutputSize, are read-only.
Peter Gadfort provided a solution in this thread. However, I still can't change OutputSize, because it is a read-only property. If you do find a solution, please post it.
The code I am trying is this:
% Adding new classes to a trained net
%% Create an editable copy of the net
load('BestNet.mat')                  % loads the trained network into 'net'
TempNet = net.saveobj;               % struct copy of the network
%% Edit the properties of the fully connected layer
FCLayer = TempNet.Layers(142,1);
FCOutputSize = FCLayer.OutputSize;
FCLayer.OutputSize = FCOutputSize+1;      % fails here: OutputSize is read-only
FCWeights = FCLayer.Weights;              % OutputSize x InputSize
% keep the trained rows, add one small random row for the new class
FCLayer.Weights = [FCWeights; 0.01*randn(1, size(FCWeights,2), 'like', FCWeights)];
FCBias = FCLayer.Bias;                    % OutputSize x 1
FCLayer.Bias = [FCBias; 0];               % one extra bias entry for the new class
%% Edit the properties of the output layer
OutputLayer = TempNet.Layers(144,1);
OLOutputSize = OutputLayer.OutputSize;
OutputLayer.OutputSize = OLOutputSize + 1;    % also read-only
OLClasses = OutputLayer.Classes;
OLClasses(end+1) = 'Obstructed';          % append the new class label
OutputLayer.Classes = OLClasses;
%% Put the edited layers back and rebuild the net
TempNet.Layers(142,1) = FCLayer;
TempNet.Layers(144,1) = OutputLayer;
net = DAGNetwork.loadobj(TempNet);        % static counterpart of saveobj
The pretrained net I am using is a GoogLeNet derivative whose last three layers were replaced by a fully connected layer, a softmax layer and a cross-entropy loss (classification output) layer. I am adding a new class called "Obstructed". Sorted alphabetically, it is the last class, which is why I append the new elements after the existing ones.
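Appending the new row at the end is only correct if "Obstructed" really is the last class in the network's order. MATLAB sorts class names by character code, so every uppercase letter sorts before every lowercase one, and `categorical` (which is what labels read with 'LabelSource','foldernames' are) orders its categories the same way `unique` does. A small check, assuming hypothetical lowercase names for the existing classes:

```matlab
% Hypothetical names for the classes the network was trained on
oldClasses = {'clear'; 'partial'};
% Class order after adding the new class: sorted by character code,
% the same order categorical uses for its categories
allClasses = unique([oldClasses; {'Obstructed'}]);
disp(allClasses)
% Position of the new class, and the new position of each old class
newRow = find(strcmp(allClasses, 'Obstructed'));
[~, oldRowsGoTo] = ismember(oldClasses, allClasses);
```

Here newRow is 1 and oldRowsGoTo is [2; 3], so with such names the trained rows would have to move down by one instead of staying in place. Mapping rows by name with `ismember`, rather than assuming the new class sorts last, removes the dependence on how the classes happen to be named.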
Answers (0)