
Training options in an MLP created with Deep Learning Toolbox

Alberto Tellaeche on 21 Nov 2022
Answered: Sai Kiran on 22 Dec 2022
Hi all,
I am trying to train a standard MLP created with Deep Learning Toolbox to classify the digits in the MNIST dataset.
Since I do not want to use a CNN for this example, I have flattened the image data, turning each 28-by-28 image into an input vector of 784 elements.
My code is as follows:
clear; clc;
filenameImagesTrain = 'train-images-idx3-ubyte.gz';
filenameLabelsTrain = 'train-labels-idx1-ubyte.gz';
filenameImagesTest = 't10k-images-idx3-ubyte.gz';
filenameLabelsTest = 't10k-labels-idx1-ubyte.gz';
XTrain = processImagesMNIST(filenameImagesTrain);
YTrain = processLabelsMNIST(filenameLabelsTrain);
XTest = processImagesMNIST(filenameImagesTest);
YTest = processLabelsMNIST(filenameLabelsTest);
sizeX = 28;
sizeY = 28;
XVectorTrain = reshape(XTrain, 28*28, 60000);
multilayer_perceptron = [
    sequenceInputLayer(sizeX*sizeY,"Name","input")
    fullyConnectedLayer(32,"Name","capa 1")
    reluLayer("Name","relu")
    fullyConnectedLayer(10,"Name","capa 2")
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classoutput")];
plot(layerGraph(multilayer_perceptron));
options = trainingOptions("sgdm","Plots","training-progress", ...
    "SequenceLength",sizeY*sizeX, ...
    "MaxEpochs",40,"MiniBatchSize",1, ...
    "InitialLearnRate",0.005,"Momentum",0.9, ...
    "ExecutionEnvironment","auto");
%"MaxEpochs",40,"MiniBatchSize",8, ...
%"Shuffle","every-epoch", ...
network = trainNetwork(XVectorTrain,categorical(YTrain),multilayer_perceptron,options);
However, the network does not train (the training-progress plot is not reproduced here).
The result stays like this until the end of training.
I know this can be done, since I have seen similar examples in Keras, but I cannot make it work in MATLAB.
I would be very grateful if someone could help me with this issue.
Best regards,

Answers (1)

Sai Kiran on 22 Dec 2022
Hi,
The trainNetwork function returns the trained model. In your code the returned model is stored in the variable network, which can then be used to make predictions.
Please refer to the following documentation to know how to use the trained model in further steps.
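As a rough illustration of using the returned network for prediction, here is a minimal, self-contained sketch on synthetic data. It is not the code from the question: the random data, the variable names (X, Y, net, YPred) and the short training settings are assumptions for illustration only. It also swaps in featureInputLayer (available from R2020b) for the question's sequenceInputLayer, on the assumption that each flattened image is better treated as a fixed-length feature vector than as a time series. With that layer, trainNetwork expects one observation per row.

```matlab
% Synthetic stand-in for flattened MNIST (assumption: not the real dataset).
rng(0);                                        % reproducible random data
numObs      = 200;                             % number of fake "images"
numFeatures = 28*28;                           % 784 values per flattened image
numClasses  = 10;                              % digits 0..9

X      = rand(numObs, numFeatures);            % rows are observations
labels = randi([0 numClasses-1], numObs, 1);   % random integer labels
Y      = categorical(labels, 0:numClasses-1);  % keep all 10 categories

% Same MLP shape as in the question, but with a feature input layer.
layers = [
    featureInputLayer(numFeatures, "Name", "input")
    fullyConnectedLayer(32, "Name", "fc1")
    reluLayer("Name", "relu")
    fullyConnectedLayer(numClasses, "Name", "fc2")
    softmaxLayer("Name", "softmax")
    classificationLayer("Name", "classoutput")];

% Short run, only to obtain a trained network object.
options = trainingOptions("sgdm", ...
    "MaxEpochs", 2, "MiniBatchSize", 32, ...
    "Shuffle", "every-epoch", "Verbose", false);

% trainNetwork returns the trained model ...
net = trainNetwork(X, Y, layers, options);

% ... which can then be used for prediction.
YPred    = classify(net, X);                   % one predicted label per row
accuracy = mean(YPred == Y);                   % fraction of matching labels
```

Because the labels here are random, the accuracy value itself is meaningless; the point is only the call pattern. On the real data, the same classify call would presumably be applied to the flattened test images arranged one image per row.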
I hope it resolves your query.
Regards,
Sai Kiran Ratna

Release

R2022a
