How can I validate a CNN after training?
I have 4 samples (classes), each containing about 51,000 images. I train the network, but every training run ends with a sudden drop in validation accuracy. Increasing the mini-batch size helped a little, but I cannot increase it any further because I run out of memory.
After reading a few help threads online, I think I can train the network without validation data, but then how do I validate it after training?
Thank you in advance.
% reset(gpuDevice(1));
cz1 = fullfile('v1');
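% Take labels from the subfolder names (one subfolder per class is assumed).
% With 'LabelSource','none' the datastore has no labels, which a
% classification network needs.
imds = imageDatastore(cz1,'LabelSource','foldernames','IncludeSubfolders',true, ...
    'FileExtensions','.mat','ReadFcn',@(filename)customreader(filename));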
tbl = countEachLabel(imds)
[trainingSet,validationSet,testSet] = splitEachLabel(imds,0.8,0.1,0.1, ...
    'randomized'); % alternatives tried: 399,49,49 or 100,10,10 files per label
layers = [
imageInputLayer([480 640 1],'Normalization', ...
'none','Name','input')
%layer 1
convolution2dLayer([3 3],64,'Stride',[1 1],'Padding',1)
batchNormalizationLayer('Name','BN1')
reluLayer('Name','relu1')
maxPooling2dLayer([3 3],'Stride',2,'Name','MP1')
%layer 4
convolution2dLayer([3 3],64,'Stride',[1 1],'Padding',1,'Name','conv4')
batchNormalizationLayer('Name','BN4')
reluLayer('Name','relu4')
%layer 5
convolution2dLayer([3 3],64,'Stride',[2 2],'Padding',1,'Name','conv5')
additionLayer(2,'Name','add2')
batchNormalizationLayer('Name','BN5')
fullyConnectedLayer(512,'Name','fc1')
batchNormalizationLayer('Name','BN7')
reluLayer('Name','relu7')
%layer 7
fullyConnectedLayer(256,'Name','fc2')
reluLayer('Name','relu8')
%layer 8
fullyConnectedLayer(4,'Name','fc3')
softmaxLayer('Name','softmax')
classificationLayer('Name','classif')
]
lgraph = layerGraph(layers);
% Skip connection around conv4/conv5: the 1x1 stride-2 convolution brings the
% MP1 output (239x319x64) to the same size as the conv5 output (120x160x64).
% The layer array above has no 'add1' or 'MP3' layer, so only one skip
% branch (into 'add2') can be connected.
skipConv1 = convolution2dLayer(1,64,'Stride',2,'Name','skipConv1');
lgraph = addLayers(lgraph,skipConv1);
lgraph = connectLayers(lgraph,'MP1','skipConv1');
lgraph = connectLayers(lgraph,'skipConv1','add2/in2');
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.001, ...
'MaxEpochs',10, ...
'Shuffle','every-epoch', ...
'ValidationData',validationSet, ...
'ValidationFrequency',4000, ...
'ValidationPatience',1, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',1,...
'Verbose',false, ...
'Plots','training-progress', ...
'MiniBatchSize', 20, ...
'CheckpointPath', 't7g');
net = trainNetwork(trainingSet,lgraph,options);
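[YPred,scores] = classify(net,testSet,'MiniBatchSize',20);
YTest = testSet.Labels;
accuracy = mean(YPred == YTest)
% With only 4 classes a top-5 score is always 100%, so report top-2 instead.
% Assumes the score columns follow the order of categories(YTest).
[~,I] = maxk(scores,2,2);              % indices of the 2 highest scores per image
top2 = mean(any(I == double(YTest),2))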
function data = customreader(filename)
    % Return the variable 'frame' stored in each .mat file.
    s = load(filename,'frame');
    data = s.frame;
end
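One way to validate after training, sketched under assumptions: train without the 'ValidationData', 'ValidationFrequency' and 'ValidationPatience' options, then classify the held-out validationSet and compare the predictions with its labels. The helper below, heldOutMetrics (a hypothetical name, not a toolbox function), returns the accuracy and a confusion matrix, assuming the true and predicted labels are categorical arrays with the same categories in the same order.

```matlab
function [acc, cm] = heldOutMetrics(YTrue, YPred)
    % heldOutMetrics (hypothetical helper): accuracy and confusion matrix
    % for a held-out set. YTrue and YPred are categorical arrays that are
    % assumed to share the same categories in the same order.
    YTrue = YTrue(:);
    YPred = YPred(:);
    assert(numel(YTrue) == numel(YPred), ...
        'heldOutMetrics:size', 'YTrue and YPred must have the same number of elements.');
    assert(isequal(categories(YTrue), categories(YPred)), ...
        'heldOutMetrics:categories', 'YTrue and YPred must have identical categories.');

    % Fraction of observations whose predicted label matches the true label.
    acc = mean(YPred == YTrue);

    % double() of a categorical gives its category index, so accumarray
    % counts (true, predicted) pairs: rows = true class, columns = predicted.
    k = numel(categories(YTrue));
    cm = accumarray([double(YTrue), double(YPred)], 1, [k k]);
end
```

For example, `YPred = classify(net,validationSet,'MiniBatchSize',20); [acc,cm] = heldOutMetrics(validationSet.Labels,YPred)`. Off-diagonal entries of `cm` show which classes get confused with each other. In a script file, place the function at the end next to `customreader`, or save it as `heldOutMetrics.m`.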
