I have 4 samples, each containing about 51,000 images. I train the network, but every training run ends with a sudden drop in validation accuracy. Increasing the mini-batch size helped a little, but I cannot increase it any further because I run out of memory.
After reading some help online, I think I can train the network without validation data, but how do I validate after training?
Thank you in advance.
% reset(gpuDevice(1));
cz1 = fullfile('v1');
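% Labels are taken from the subfolder names; countEachLabel and splitEachLabel need a labelled datastore.
imds = imageDatastore(cz1,'LabelSource','foldernames','IncludeSubfolders',true,'FileExtensions','.mat','ReadFcn',@(filename)customreader(filename));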
tbl = countEachLabel(imds)
[trainingSet,validationSet, testSet] = splitEachLabel(imds,0.8, 0.1, 0.1 ...
...%399,49,49 ...
...%100,10,10 ...
,'randomized');
layers = [
imageInputLayer([480 640 1],'Normalization', ...
'none','Name','input')
%layer 1
convolution2dLayer([3 3],64,'Stride',[1 1],'Padding',1)
batchNormalizationLayer('Name','BN1')
reluLayer('Name','relu1')
maxPooling2dLayer([3 3],'Stride',2,'Name','MP1')
%layer 4
convolution2dLayer([3 3],64,'Stride',[1 1],'Padding',1,'Name','conv4')
batchNormalizationLayer('Name','BN4')
reluLayer('Name','relu4')
%layer 5
convolution2dLayer([3 3],64,'Stride',[2 2],'Padding',1,'Name','conv5')
additionLayer(2,'Name','add2')
batchNormalizationLayer('Name','BN5')
fullyConnectedLayer(512,'Name','fc1')
batchNormalizationLayer('Name','BN7')
reluLayer('Name','relu7')
%layer 7
fullyConnectedLayer(256,'Name','fc2')
reluLayer('Name','relu8')
%layer 8
fullyConnectedLayer(4,'Name','fc3')
softmaxLayer('Name','softmax')
classificationLayer('Name','classif')
]
lgraph = layerGraph(layers);
skipConv1 = convolution2dLayer(1,64,'Stride',2,'Name','skipConv1');
lgraph = addLayers(lgraph,skipConv1);
lgraph = connectLayers(lgraph,'MP1','skipConv1');
lgraph = connectLayers(lgraph,'skipConv1','add1/in2')
skipConv2 = convolution2dLayer(1,64,'Stride',2,'Name','skipConv2');
lgraph = addLayers(lgraph,skipConv2);
lgraph = connectLayers(lgraph,'MP3','skipConv2');
lgraph = connectLayers(lgraph,'skipConv2','add2/in2')
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.001, ...
'MaxEpochs',10, ...
'Shuffle','every-epoch', ...
'ValidationData',validationSet, ...
'ValidationFrequency',4000, ...
'ValidationPatience',1, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',1,...
'Verbose',false, ...
'Plots','training-progress', ...
'MiniBatchSize', 20, ...
'CheckpointPath', 't7g');
net = trainNetwork(trainingSet,lgraph,options);
[YPred,scores] = classify(net,testSet,'MiniBatchSize',20);
[S,I] = maxk(scores',5);
YValidation = testSet.Labels;
accuracy = sum(YPred == YValidation)/numel(YValidation)
top5 = sum(sum(tbl.Label(I)' == YValidation))/numel(YValidation)
function data = customreader(filename)
% Load the variable 'frame' from the MAT-file and return it as the image data.
s = load(filename,'frame');
data = s.frame;
end

Accepted Answer

Joss Knight on 16 Mar 2022

The easiest way to validate after training for classification is to do exactly what your example code does to check the accuracy on the test set, but with the validation set. To compute the cross-entropy loss rather than the accuracy, you might need to implement the crossentropy function yourself. Alternatively, you could pass your validation data in instead of the training data and train for a few iterations to get some numbers.
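As a rough illustration of that suggestion, the sketch below computes accuracy and a mean cross-entropy from predictions and scores. The toy data and the helper name meanCrossEntropy are hypothetical and only there so the snippet runs on its own; with the trained network, the commented lines would supply the real values. It also assumes that the columns of the scores returned by classify follow the order of categories(YVal), which is worth checking against your own network.

```matlab
% Toy stand-ins so the sketch runs on its own. With the trained network, use instead:
%   [YPredVal, scoresVal] = classify(net, validationSet, 'MiniBatchSize', 20);
%   YVal = validationSet.Labels;
%   classNames = categories(YVal);   % ASSUMPTION: matches the column order of scoresVal
classNames = {'a','b'};
YVal       = categorical({'a';'b';'b'}, classNames);
scoresVal  = [0.9 0.1; 0.2 0.8; 0.6 0.4];          % one row per image, one column per class
[~, k]     = max(scoresVal, [], 2);                % index of the highest score in each row
YPredVal   = categorical(reshape(classNames(k), [], 1), classNames);

% Fraction of validation images whose predicted label matches the true label.
valAccuracy = mean(YPredVal == YVal)

% Mean negative log-probability assigned to the true class.
valLoss = meanCrossEntropy(scoresVal, YVal, classNames)

function loss = meanCrossEntropy(scores, labels, classNames)
% scores: N-by-K class probabilities, labels: N-by-1 categorical,
% classNames: K class names in the column order of scores (hypothetical helper).
[~, col] = ismember(string(labels), string(classNames));
n = size(scores, 1);
p = scores(sub2ind(size(scores), (1:n)', col(:)));  % probability of the true class per row
loss = -mean(log(max(p, eps)));                     % clamp to avoid log(0)
end
```

Local functions in a script have to sit at the end of the file, so in the question's script the helper would go next to customreader.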
The simplest way to get rid of this unusual fall in accuracy at the final iteration is to use moving average statistics for your batch normalization layers. Set the BatchNormalizationStatistics training option to 'moving'. (See the documentation.)
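A minimal sketch of that setting, reusing a few values from the question's trainingOptions call. It assumes a release in which the option exists (R2021a or later, per the comments in this thread); the remaining options from the question can be added back unchanged.

```matlab
% Requires Deep Learning Toolbox, R2021a or later (ASSUMPTION about your installation).
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.001, ...
    'MaxEpochs',10, ...
    'MiniBatchSize',20, ...
    'Shuffle','every-epoch', ...
    'BatchNormalizationStatistics','moving');   % running estimates instead of a final pass
```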

4 Comments

Kwasi on 17 Mar 2022
Thank you very much Joss Knight for your help.
Kwasi on 20 Mar 2022
Hello Joss,
Is the BatchNormalizationStatistics training option available in MATLAB R2019b?
The linked page shows it as available, but my code reports that it is not:
Error using trainingOptions (line 285)
'BatchNormalizationStatistics' is not an option for solver 'sgdm'.
Thank you.
Joss Knight on 21 Mar 2022
No, it was introduced in R2021a.
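For readers unsure which side of that cutoff they are on, a small check along these lines might help. It relies on MATLAB 9.10 being the version number of R2021a; the variable name hasMovingStats is only illustrative.

```matlab
% MATLAB 9.10 corresponds to R2021a, the release named in the comment above.
hasMovingStats = ~verLessThan('matlab','9.10');
if hasMovingStats
    disp('BatchNormalizationStatistics should be accepted by trainingOptions.');
else
    disp('This release predates R2021a; omit BatchNormalizationStatistics.');
end
```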
Kwasi on 21 Mar 2022
Thank you for your response and the answer.

Release: R2019b

Asked: 14 Mar 2022

Commented: 21 Mar 2022
