
Overfitting in a deep neural network

Muhammad on 20 Apr 2023
Commented: Muhammad on 30 Apr 2023
I am using the ResNet-18 CNN architecture with transfer learning for classification. After training and testing, the model is overfitting.
Here is my code. Can anyone tell me what changes I need to make to it? The attached result file shows the overfitting.
```matlab
clear all
close all
% Load the images; each subfolder name becomes the class label
imds = imageDatastore("D:\DatasetJPG", ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7); % 70% for training, 30% for validation
net = resnet18; % the first time, install "Deep Learning Toolbox Model for ResNet-18 Network" from the Add-On Explorer
% Replace Final Layers
numClasses = numel(categories(imdsTrain.Labels));
lgraph = layerGraph(net);
newFCLayer = fullyConnectedLayer(numClasses,'Name','new_fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'fc1000',newFCLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'ClassificationLayer_predictions',newClassLayer);
% Train Network
inputSize = net.Layers(1).InputSize;
imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-5,5], ...
    'RandXTranslation',[-5 5], ...
    'RandYTranslation',[-5 5]);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-3, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',5, ...
    'Verbose',false, ...
    'Plots','training-progress');
trainedNet = trainNetwork(augimdsTrain,lgraph,options);
% Evaluate on the validation set
YPred = classify(trainedNet,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
C = confusionmat(imdsValidation.Labels,YPred)
cm = confusionchart(imdsValidation.Labels,YPred);
cm.Title = 'Confusion Matrix for Validation Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
```

Accepted Answer

Sugandhi on 28 Apr 2023
Edited: Sugandhi on 28 Apr 2023
Hi Muhammad,
I understand that you are using the ResNet-18 CNN architecture with transfer learning for classification, and that the model overfits after training and testing.
Based on the code you provided, here are some ways to reduce overfitting in your ResNet-18 model:
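  1. Increase the amount of data augmentation: Data augmentation applies random transformations to the training images during training, which adds variability to the data and makes the model more robust to overfitting. Your current augmenter only applies small rotations (±5 degrees) and translations (±5 pixels). You can widen these ranges and add reflections ('RandXReflection', 'RandYReflection') and scaling ('RandScale') in imageDataAugmenter; only use a reflection if a flipped image still belongs to the same class. Brightness and contrast changes are not imageDataAugmenter options, so they would need a custom transform on the datastore (for example, with jitterColorHSV from Image Processing Toolbox).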
  2. Use dropout regularization: Dropout randomly sets a fraction of its input elements to zero at each training iteration, which keeps the model from relying too heavily on particular features and encourages it to learn more general representations. You can add a dropoutLayer (Deep Learning Toolbox) immediately before the new fully connected layer, that is, between the 'pool5' layer and 'new_fc', so that dropout acts on the pooled features fed to the classifier.
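  3. Reduce the learning rate: With transfer learning, a lower learning rate changes the pretrained weights less during fine-tuning, which can help the network keep the general features it learned during pretraining instead of fitting your small training set too closely (a learning rate that is too high mainly makes training unstable). You can try lowering 'InitialLearnRate' in trainingOptions from 1e-3 to 1e-4 or 1e-5. The new fully connected layer still learns 10 times faster than the rest of the network because its 'WeightLearnRateFactor' and 'BiasLearnRateFactor' are set to 10.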
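  4. Use early stopping: Early stopping monitors the validation loss during training and stops training when the loss stops improving, which indicates overfitting. trainingOptions has no 'EarlyStopping' option; instead, set 'ValidationPatience' to a value such as 5 or 10. It is the number of times the validation loss can be larger than or equal to the previously smallest loss before training stops, so it counts validation checks (one every 'ValidationFrequency' iterations), not epochs. In R2021b and later, you can also set 'OutputNetwork' to 'best-validation-loss' so that trainNetwork returns the network with the lowest validation loss.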
  5. Add more training data: Overfitting is more likely when the model is not exposed to enough diverse training data. If possible, collect more labeled images; data augmentation (point 1) can also help, but it only adds variations of the images you already have.
  6. Try using a smaller model: A model with many parameters is more prone to overfitting when the training set is small. ResNet-18 has about 11.7 million parameters. ResNet-9 is not one of the pretrained networks in Deep Learning Toolbox, so it would have to be built and trained from scratch; smaller pretrained alternatives include squeezenet and mobilenetv2, or you can design a custom architecture with fewer layers and check whether it reduces overfitting.
  7. Regularize the fully connected layers: The built-in options provide L2 regularization (weight decay); L1 regularization would require a custom training loop. Set the global 'L2Regularization' factor in trainingOptions (default 1e-4), and optionally scale it for the new layer with the 'WeightL2Factor' and 'BiasL2Factor' options of fullyConnectedLayer. The penalty applied to a layer's weights is the global factor multiplied by that layer's factor.
Implementing these changes can help in reducing overfitting in your ResNet-18 model and improving its generalization performance.
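As a rough illustration of points 1, 3 and 7, the augmenter, the new fully connected layer and the training options could be adjusted as in the sketch below. The augmentation ranges, the L2 values and numClasses = 3 are assumptions chosen for illustration, not tuned settings for this dataset.

```matlab
% Stronger augmentation (point 1): wider rotation/translation ranges plus
% horizontal reflection and scaling. The ranges are illustrative assumptions;
% keep 'RandXReflection' only if a flipped image keeps its class label.
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection', true, ...
    'RandRotation', [-15 15], ...
    'RandScale', [0.9 1.1], ...
    'RandXTranslation', [-10 10], ...
    'RandYTranslation', [-10 10]);

% New classification head with a stronger L2 penalty on its weights (point 7).
% The penalty on these weights is L2Regularization * WeightL2Factor.
numClasses = 3;   % hypothetical value; in the script use numel(categories(imdsTrain.Labels))
newFCLayer = fullyConnectedLayer(numClasses, 'Name', 'new_fc', ...
    'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10, ...
    'WeightL2Factor', 2);

% Lower base learning rate (point 3) and a larger global L2 factor
% than the default 1e-4 (point 7).
options = trainingOptions('sgdm', ...
    'MiniBatchSize', 10, ...
    'MaxEpochs', 20, ...
    'InitialLearnRate', 1e-4, ...
    'L2Regularization', 5e-4, ...
    'Shuffle', 'every-epoch', ...
    'Verbose', false);
```

To use these in the original script, pass imageAugmenter to augmentedImageDatastore as before, use newFCLayer in the replaceLayer call, and keep the 'ValidationData', 'ValidationFrequency' and 'Plots' settings from the original trainingOptions call.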
Muhammad on 30 Apr 2023
Can you please help me modify my code to add dropout regularization and early stopping?
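One way to add both changes to the question's script is sketched below. It replaces the "Replace Final Layers" section and the trainingOptions call, and assumes net, imdsTrain, augimdsTrain and augimdsValidation are created exactly as in the question. The dropout probability of 0.5 and the patience of 5 are assumed starting values, and 'OutputNetwork' requires R2021b or later.

```matlab
% --- Replace Final Layers (with dropout) ---
% Assumes net = resnet18 and imdsTrain from the question's script.
numClasses = numel(categories(imdsTrain.Labels));
lgraph = layerGraph(net);

% Replace 'fc1000' with a dropout layer followed by the new fully connected
% layer, so dropout acts on the pooled features coming from 'pool5'.
% 0.5 is an assumed starting probability, not a tuned value.
newLayers = [
    dropoutLayer(0.5, 'Name', 'new_dropout')
    fullyConnectedLayer(numClasses, 'Name', 'new_fc', ...
        'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10)];
lgraph = replaceLayer(lgraph, 'fc1000', newLayers);

newClassLayer = classificationLayer('Name', 'new_classoutput');
lgraph = replaceLayer(lgraph, 'ClassificationLayer_predictions', newClassLayer);

% --- Training options with early stopping ---
% 'ValidationPatience' is the number of validation checks (one every
% 'ValidationFrequency' iterations) at which the validation loss may be
% larger than or equal to its smallest value so far before training stops.
% 'OutputNetwork' = 'best-validation-loss' (R2021b or later) returns the
% network from the iteration with the lowest validation loss.
options = trainingOptions('sgdm', ...
    'MiniBatchSize', 10, ...
    'MaxEpochs', 20, ...
    'InitialLearnRate', 1e-3, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augimdsValidation, ...
    'ValidationFrequency', 5, ...
    'ValidationPatience', 5, ...
    'OutputNetwork', 'best-validation-loss', ...
    'Verbose', false, ...
    'Plots', 'training-progress');

trainedNet = trainNetwork(augimdsTrain, lgraph, options);
```

Passing a layer array to replaceLayer connects the new layers in sequence and reattaches them to the layers that surrounded 'fc1000', so no manual disconnectLayers/connectLayers calls are needed.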



Categories

Find more on Image Data Workflows in Help Center and File Exchange.
