Why won't my LSTM training accuracy surpass 51.20%?

Kimberly Cardillo on 20 Aug 2020
Answered: Aditya Patil on 21 Sep 2020
I am running an LSTM neural network to classify signals into two classes: crack-related signals (top image) and noise signals (bottom image). I have 125 signals as my training data and 59 signals as my testing data.
The network layers and training options are shown in the code below. No matter which training options I change ('sgdm' vs. 'adam', maximum number of epochs, initial learning rate, etc.), I consistently get a training accuracy of 51.20%. A screenshot of the training progress appears below the code, followed by a confusion matrix for my testing data. Originally I was working with only 27 training signals and 27 testing signals and got an accuracy of only 62%. I thought I might simply not have enough data, but after adding more data my training accuracy went down.
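inLayer = sequenceInputLayer(1);                % one channel: each signal is a 1-by-T sequence
lstm = bilstmLayer(100,'OutputMode','last');    % BiLSTM, 100 hidden units; output only the last time step
outLayers = [
    fullyConnectedLayer(2)                      % two classes: crack vs. noise
    softmaxLayer()
    classificationLayer()
    ];
layers = [inLayer;lstm;outLayers];
options = trainingOptions('sgdm', ...
    'MaxEpochs',8, ...
    'InitialLearnRate',0.05, ...
    'Plots','training-progress');
% dataTrain: cell array of 1-by-T signals; fTrain: categorical class labels
net = trainNetwork(dataTrain,fTrain,layers,options);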
Does anyone know what I can do to improve my accuracy?

Answers (1)

Aditya Patil on 21 Sep 2020
From the confusion matrix, we can see that the model predicts all data as noise. This generally happens when the model cannot find any correlation between the input and the output, so it falls back to predicting whichever class has the highest number of samples.
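The training accuracy in the question is consistent with this: 51.20% is exactly 64/125, which is what always predicting one class would score if 64 of the 125 training signals belonged to that class. A minimal MATLAB sketch of this majority-class baseline check follows. The 64 noise / 61 crack split is an assumption chosen to reproduce 51.20%; in practice you would run it on the real `fTrain` label vector passed to `trainNetwork`.

```matlab
% Hypothetical labels chosen to match the question: 64 noise + 61 crack = 125.
% Replace with the real fTrain used for training.
fTrain = categorical([repmat("noise",64,1); repmat("crack",61,1)]);

counts = countcats(fTrain);               % samples per class (categories sorted: crack, noise)
cats = categories(fTrain);
[~, majorityIdx] = max(counts);

% A "model" that predicts the majority class for every signal
pred = repmat(categorical(cats(majorityIdx), cats), size(fTrain));
baselineAcc = mean(pred == fTrain);
fprintf('Majority class: %s, baseline accuracy: %.2f%%\n', ...
    cats{majorityIdx}, 100*baselineAcc);
% prints: Majority class: noise, baseline accuracy: 51.20%
```

If the network's accuracy matches this number, it has learned nothing beyond the class proportions.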
There are a few options that you may try, for example:
  1. Try preprocessing the dataset. Using a Fourier transform might be an option, depending on the data.
  2. Try starting with a simpler network to see if it works any better. If it does, that suggests an issue with your current model.
  3. See if there is a pretrained model available for the task, and then fine-tune it for your dataset using transfer learning.
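The first two options can be sketched together in MATLAB. This is a minimal sketch under stated assumptions, not the answerer's method: synthetic signals (a 50 Hz tone plus noise for "crack", noise alone for "noise", at an assumed 1 kHz sample rate) stand in for the real data, `nBins` is a hypothetical parameter, and a nearest-centroid classifier is swapped in as a deliberately simple non-neural baseline for option 2. Each signal is z-scored, converted to a one-sided FFT magnitude spectrum, and averaged into equal-width frequency bands so that signals of different lengths yield feature vectors of the same length.

```matlab
rng(0);                                   % reproducible synthetic data
nBins = 32;                               % hypothetical number of frequency bands
nSig = 40;

% Synthetic stand-ins (hypothetical): first half "crack" (tone + noise),
% second half "noise" (noise only), with varying lengths.
dataTrain = cell(nSig,1);
labels = strings(nSig,1);
for k = 1:nSig
    n = 900 + randi(200);                 % random length 901..1100 samples
    t = (0:n-1)/1000;                     % assumed 1 kHz sample rate
    x = randn(1,n);
    if k <= nSig/2
        x = x + 2*sin(2*pi*50*t);         % 50 Hz component marks a "crack"
        labels(k) = "crack";
    else
        labels(k) = "noise";
    end
    dataTrain{k} = x;
end
fTrain = categorical(labels);

% Option 1: z-score, one-sided FFT magnitude, then average into nBins bands.
feat = zeros(nSig, nBins);
for k = 1:nSig
    x = dataTrain{k};
    x = (x - mean(x)) / std(x);
    X = abs(fft(x));
    X = X(1:floor(numel(x)/2)+1);         % DC up to Nyquist
    edges = round(linspace(0, numel(X), nBins+1));
    for b = 1:nBins
        feat(k,b) = mean(X(edges(b)+1:edges(b+1)));
    end
end
feat = log1p(feat);                       % compress dynamic range

% Option 2 (swapped-in simple model): nearest-centroid classifier.
cats = categories(fTrain);
centroids = zeros(numel(cats), nBins);
for c = 1:numel(cats)
    centroids(c,:) = mean(feat(fTrain == cats{c}, :), 1);
end
dists = zeros(nSig, numel(cats));
for c = 1:numel(cats)
    dists(:,c) = sum((feat - centroids(c,:)).^2, 2);   % squared distance to each centroid
end
[~, idx] = min(dists, [], 2);             % index of nearest centroid
trainAcc = mean(idx == double(fTrain));   % double() gives the category index
fprintf('Nearest-centroid training accuracy: %.2f%%\n', 100*trainAcc);
```

If a baseline this simple separates the real signals well while the BiLSTM stays at the majority-class rate, the information is present in the data and the problem more likely lies in the network or its training setup. The per-signal spectra could also be fed to the BiLSTM as sequences in place of the raw signals.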
