Overfitting deep neural network

28 views (last 30 days)
Muhammad on 20 Apr 2023
Commented: Muhammad on 30 Apr 2023
I am using the ResNet-18 CNN architecture with transfer learning for classification. The model overfits after training and testing.
Here is my code. Can anyone please tell me what changes I should make to the code below? Please see the attached result file, which shows the overfitting.
clear all
close all
imds = imageDatastore("D:\DatasetJPG", ...
'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7); % 70% for training, 30% for validation
net=resnet18; % the first time, you have to install the support package from the Add-On Explorer
%Replace Final Layers
numClasses = numel(categories(imdsTrain.Labels));
lgraph = layerGraph(net);
newFCLayer = fullyConnectedLayer(numClasses,'Name','new_fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'fc1000' ,newFCLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'ClassificationLayer_predictions',newClassLayer);
%Train Network
inputSize = net.Layers(1).InputSize;
imageAugmenter = imageDataAugmenter( ...
'RandRotation',[-5,5], ...
'RandXTranslation',[-5 5], ...
'RandYTranslation',[-5 5]);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',20, ...
'InitialLearnRate',1e-3, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',5, ...
'Verbose',false, ...
'Plots','training-progress');
trainedNet = trainNetwork(augimdsTrain,lgraph,options);
YPred = classify(trainedNet,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
C = confusionmat(imdsValidation.Labels,YPred)
cm = confusionchart(imdsValidation.Labels,YPred);
cm.Title = 'Confusion Matrix for Validation Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Accepted Answer

Sugandhi on 28 Apr 2023
Edited: Sugandhi on 28 Apr 2023
Hi Muhammad,
I understand that you are using the ResNet-18 CNN architecture with transfer learning for classification, and that the model overfits after training and testing.
Based on the code you provided, here are some ways to reduce overfitting in your ResNet-18 model:
  1. Increase the amount of data augmentation: Data augmentation applies random transformations to the images during training, so the network sees a slightly different version of each image in every epoch. This added variability makes the model less likely to memorize the training set. Your augmenter currently uses only small rotations and translations. You can widen those ranges and add transformations that imageDataAugmenter supports, such as horizontal or vertical reflection (RandXReflection, RandYReflection), scaling (RandScale), and shear (RandXShear, RandYShear). Brightness and contrast changes are not imageDataAugmenter options, so they would need a custom preprocessing step.
  2. Use dropout regularization: Dropout randomly sets a fraction of a layer's inputs to 0 at each training iteration, which keeps the model from relying too heavily on particular features and encourages it to learn more general representations. You can insert a dropout layer immediately before the new fully connected layer (rather than after it, where it would zero out class scores) by using the dropoutLayer function from MATLAB's Deep Learning Toolbox.
  3. Reduce the learning rate: In transfer learning, a high learning rate can quickly move the pretrained weights away from their useful starting values and make training unstable, which can hurt generalization. You can try reducing the initial learning rate in trainingOptions, for example to 1e-4 or 1e-5. Note that your new fully connected layer uses a learn-rate factor of 10, so it still trains at 10 times the rate you set.
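  4. Use early stopping: Early stopping monitors the validation loss during training and stops when the loss no longer improves, which is a sign of overfitting. trainingOptions has no option named EarlyStopping; the corresponding option is ValidationPatience. Setting it to a value such as 5 or 10 stops training once the validation loss has failed to improve for that many validation checks. Because your ValidationFrequency is 5 iterations, consider validating less often (for example, once per epoch) so that the patience window is not too short.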
  5. Add more training data: Overfitting can occur when the model is not exposed to enough diverse training data. You can consider increasing the size of your training dataset by collecting more data, or by using data augmentation techniques to generate synthetic data.
  6. Try a smaller model or train fewer layers: ResNet-18 has roughly 11 million parameters, which can make it prone to overfitting when the training dataset is small. Deep Learning Toolbox does not appear to provide a pretrained ResNet-9, so the practical options are a smaller pretrained network such as SqueezeNet, a custom architecture with fewer layers, or freezing the early ResNet-18 layers (setting their learn-rate factors to 0) so that only the later layers are fine-tuned.
  7. Regularize the weights: L2 regularization (weight decay) penalizes large weights and can reduce overfitting. fullyConnectedLayer has no WeightRegularization or BiasRegularization options. Instead, set the global strength with the L2Regularization option of trainingOptions (default 1e-4) and, if needed, scale it per layer with the WeightL2Factor and BiasL2Factor properties of the layer. L1 regularization is not available as a built-in trainingOptions setting.
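As a rough sketch of the dropout and per-layer regularization items above, the layer replacement from the question could be adapted as follows. It places the dropout layer just before the new fully connected layer, which is the usual placement. The layer name 'new_dropout', the dropout probability 0.5, the WeightL2Factor value, and numClasses = 4 are illustrative assumptions, not values from the original post.

```matlab
% Sketch only: requires Deep Learning Toolbox and the ResNet-18 support package.
net = resnet18;
lgraph = layerGraph(net);

% Hypothetical class count; in the question's script use
% numel(categories(imdsTrain.Labels)) instead.
numClasses = 4;

% New classification head: dropout followed by the new fully connected layer.
% WeightL2Factor scales the global L2Regularization for this layer only
% (the value 2 is an assumption for illustration).
newHead = [
    dropoutLayer(0.5,'Name','new_dropout')
    fullyConnectedLayer(numClasses,'Name','new_fc', ...
        'WeightLearnRateFactor',10,'BiasLearnRateFactor',10, ...
        'WeightL2Factor',2)];

% Replacing 'fc1000' with a two-layer array inserts both layers in sequence,
% so dropout ends up between the last pooling layer and new_fc.
lgraph = replaceLayer(lgraph,'fc1000',newHead);

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'ClassificationLayer_predictions',newClassLayer);
```

The resulting lgraph is passed to trainNetwork exactly as in the question. A dropout probability of 0.5 is only a common starting point; whether it helps depends on the dataset.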
Implementing these changes can help in reducing overfitting in your ResNet-18 model and improving its generalization performance.
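The augmentation, learning-rate, weight-decay, and early-stopping suggestions could be combined along these lines. Early stopping is expressed through the ValidationPatience option, which is the option trainingOptions provides for this purpose. The augmentation ranges, the L2Regularization value, the patience of 5, and the use of 'randomized' in splitEachLabel are assumptions for illustration, not tuned settings.

```matlab
% Sketch only: assumes Deep Learning Toolbox and the dataset folder from the question.
imds = imageDatastore("D:\DatasetJPG", ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
% 'randomized' avoids a split that depends on file order.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

inputSize = [224 224 3];   % ResNet-18 input size, same as net.Layers(1).InputSize

% Stronger augmentation than the original +/-5 ranges (illustrative values).
% Drop RandXReflection if mirrored images are not valid for your classes.
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandRotation',[-15 15], ...
    'RandScale',[0.9 1.1], ...
    'RandXTranslation',[-10 10], ...
    'RandYTranslation',[-10 10]);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

% Validate about once per epoch so that the patience window spans whole epochs.
miniBatchSize = 10;
valFrequency = max(1,floor(numel(imdsTrain.Files)/miniBatchSize));

options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...          % lower than the original 1e-3
    'L2Regularization',1e-3, ...          % global weight decay (default is 1e-4)
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'ValidationPatience',5, ...           % stop after 5 validations without improvement
    'Verbose',false, ...
    'Plots','training-progress');

% Training then proceeds as in the question, with lgraph built as before:
% trainedNet = trainNetwork(augimdsTrain,lgraph,options);
```

With these settings, training stops once the validation loss has not improved for about 5 epochs in a row, instead of always running all 20 epochs.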
  1 Comment
Muhammad on 30 Apr 2023
Could you please help me modify my code to add dropout regularization and early stopping?
