trainNetwork function: feature dimensions problem
I'm implementing an LSTM classification model.
The inputs are several time series, and the outputs are categorical values (labels).
This is my code:
close all
clearvars -except ts data labels
clc
for ii = 1 : numel(ts)
    timeVec = datenum(ts(ii).Timetable.Time);
    data{ii} = [timeVec, data{ii}];
end
numChannels = size(data{1}, 2);
numHiddenUnits = 120;
numClasses = 2;
layers = [
    sequenceInputLayer(numChannels)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
maxEpochs = 200;
miniBatchSize = 27;
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs', maxEpochs, ...
    'MiniBatchSize', miniBatchSize, ...
    'GradientThreshold', 1, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
idxTrain = [1 2 3 9 10 11];
idxTest = [4 7 8];
Xtrain = cell(numel(idxTrain), 1);
Ytrain = categorical(labels(idxTrain))';
for ii = 1 : numel(idxTrain)
    Xtrain{ii} = data{idxTrain(ii)};
    Ytrain(ii) = categorical(labels(idxTrain(ii)));
end
Xtest = cell(numel(idxTest), 1);
Ytest = categorical(labels(idxTest))';
for ii = 1 : numel(idxTest)
    Xtest{ii} = data{idxTest(ii)};
    Ytest(ii) = categorical(labels(idxTest(ii)));
end
net = trainNetwork(Xtrain', Ytrain, layers, options);
When I run it, I get this error:
Error using trainNetwork
Invalid training data. Predictors must be a N-by-1 cell array of sequences, where N is the
number of sequences. All sequences must have the same feature dimension and at least one time
step.
Error in main (line 244)
net = trainNetwork(Xtrain', Ytrain, layers, options);
I think it is related to the different sizes of the matrices inside each of the Xtrain cells.
Here is a screenshot showing the dimensions of the different cells of Xtrain.

Is there any way to train the model using inputs with different dimensions?
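A quick way to test the hypothesis about different sizes is to print the size of every sequence and check the two requirements the error message names: the orientation of the cell array and a common feature dimension. The sketch below uses hypothetical random data laid out like data{ii} in the code above (one row per time step). Note that Xtrain is created with cell(numel(idxTrain), 1) and is therefore already N-by-1, so the Xtrain' in the trainNetwork call turns it into a 1-by-N cell array.

```matlab
% Hypothetical example data: three sequences stored time-by-features (s-by-c),
% like data{ii} in the question (one row per time step, 12 columns).
Xtrain = {rand(50, 12); rand(80, 12); rand(65, 12)};

% Print the size of every sequence.
for ii = 1:numel(Xtrain)
    fprintf('Sequence %d: %d-by-%d\n', ii, size(Xtrain{ii}, 1), size(Xtrain{ii}, 2));
end

% trainNetwork expects an N-by-1 cell array of predictors.
isColumnCell = iscolumn(Xtrain);        % true: Xtrain is 3-by-1
transposedIsColumn = iscolumn(Xtrain'); % false: Xtrain' is 1-by-3

% All sequences must share the same feature dimension. In this s-by-c layout
% the features are the columns, so compare size(x, 2) across sequences.
numFeatures = cellfun(@(x) size(x, 2), Xtrain);
sameFeatureDim = all(numFeatures == numFeatures(1));
```

Here the lengths (50, 80, 65) differ while the feature count (12) is the same. Because trainNetwork reads each sequence as c-by-s, passing s-by-c matrices makes the number of time steps look like the feature dimension, so sequences of different lengths appear to have different feature dimensions.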
5 Comments
Ganesh
on 20 Jun 2024
Ideally, you should consider combining all the rows.
If each row of Xtrain represents a different "type" of data, you might want to train a separate model for each row. If all the rows represent the same kind of data and each has a corresponding Ytrain value, I think it is more justified to combine the rows into a larger training set.
Marco
on 20 Jun 2024
Ayush Modi
on 20 Jun 2024
Hi Marco,
You are facing this issue because the training input must be in N-by-1 format, and your cell array has multiple entries. As @Ganesh suggested, the best practice is to combine all the rows. If the rows represent different things, it is advisable to train different models.
However, if you need to train a single model on one cell item at a time, you can train the model on the first cell, save the model's parameters, load the model again for the next cell's data, and continue training the same model on that cell item.
Ganesh
on 20 Jun 2024
@Marco, how would you differentiate each of the "time" frames? Suppose you manage to train the model and now want to make a prediction: you give it a set of 12 columns as input, and the model cannot tell which of the 6 time frames you are referring to.
Marco
on 20 Jun 2024
Answers (1)
Ayush Aniket
on 20 Jun 2024
0 votes
Hi Marco,
For a dataset of sequences, the trainNetwork function in MATLAB requires an N-by-1 cell array in which each element is a c-by-s matrix, where c is the number of features of the sequence and s is the sequence length. Refer to the following documentation page to read about the various input formats: https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html#mw_36a68d96-8505-4b8d-b338-44e1efa9cc5e
From the attached screenshot, it seems that each sequence in Xtrain is stored as an s-by-c matrix. You should restructure the data into the required c-by-s format.
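A minimal sketch of that restructuring, assuming each data{ii} is s-by-c as in the question: transpose every sequence with cellfun and keep the cell array N-by-1 (that is, pass the cell array itself to trainNetwork, not its transpose). The random data is hypothetical and XtrainCS is an illustrative name.

```matlab
% Hypothetical sequences stored s-by-c (time steps in rows, features in columns),
% as data{ii} is built in the question.
Xtrain = {rand(50, 12); rand(80, 12); rand(65, 12)};

% Transpose each sequence to c-by-s, the layout trainNetwork expects.
% 'UniformOutput', false keeps the result as a cell array.
XtrainCS = cellfun(@transpose, Xtrain, 'UniformOutput', false);

% Make sure the cell array itself is N-by-1.
XtrainCS = XtrainCS(:);

% In the c-by-s layout the feature dimension is the first one.
numChannels = size(XtrainCS{1}, 1);
```

Because size(data{1}, 2) in the question is the column count of an s-by-c matrix, it already equals the number of features c, so sequenceInputLayer(numChannels) can stay as it is after the transpose.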
Also, you do not need all the sequences to have the same length; the software applies padding internally. You can read about it here: https://www.mathworks.com/help/deeplearning/ug/classify-sequence-data-using-lstm-networks.html#ClassifySequenceDataUsingLSTMNetworksExample-2
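To illustrate what that padding does, here is a minimal sketch that right-pads c-by-s sequences with zeros up to the length of the longest one in a mini-batch. It assumes the trainingOptions defaults ('SequenceLength' set to 'longest', 'SequencePaddingDirection' set to 'right', 'SequencePaddingValue' set to 0); the loop is a hypothetical illustration of the idea, not code you need to write yourself.

```matlab
% Hypothetical mini-batch: three c-by-s sequences with 12 features each
% and different numbers of time steps.
batch = {rand(12, 50); rand(12, 80); rand(12, 65)};

% Length of each sequence and of the longest one.
seqLengths = cellfun(@(x) size(x, 2), batch);
maxLen = max(seqLengths);

% Right-pad every sequence with zeros up to maxLen (illustration only).
padded = cell(size(batch));
for ii = 1:numel(batch)
    x = batch{ii};
    padded{ii} = [x, zeros(size(x, 1), maxLen - size(x, 2))];
end
```

Padding only adds time steps (columns), which is why sequences of different lengths can share a mini-batch as long as c is the same for all of them.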
Note: the trainNetwork function is no longer recommended. You can use the trainnet function instead. To read about the input formats for sequence datasets in the trainnet function, refer to the following link: