Output Function to Save Net on Every Validation
I'm curious whether it's possible to define an output function that captures the current state of the network during training and stores each snapshot in a structure. I already get [net,tr] = trainNetwork() when training finishes; I'd like the same pair at intervals while training is still running.
I can't use checkpoints because I am using an ADAM solver for my network.
1: Net, TR
2: Net, TR
3: Net, TR
4: Net, TR
etc.
1 Comment
Ameer Hamza
6 May 2020
It seems that the outputFcn cannot save the network itself after each iteration. Is saving just the state of network training enough?
Answers (1)
Ameer Hamza
6 May 2020
Edited: Ameer Hamza
6 May 2020
If you just want to save the training states, then try the following example. It is adapted from this example: https://www.mathworks.com/help/releases/R2020a/deeplearning/ref/trainingoptions.html#bvniuj4
% Load the digit data and hold out 1000 random images for validation.
[XTrain,YTrain] = digitTrain4DArrayData;
idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Register outFcn as the output function.
options = trainingOptions('sgdm', ...
    'MaxEpochs',8, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'OutputFcn',@outFcn);

% Collect the info structs in a global variable shared with outFcn.
global training_state
training_state = [];
net = trainNetwork(XTrain,YTrain,layers,options);

% Called by trainNetwork before training starts, after each iteration,
% and once after training has finished.
function stop = outFcn(info)
    global training_state
    training_state = [training_state info];  % append this call's info struct
    stop = false;                            % returning true would stop training early
end
Use of the global variable can be avoided if you define your own handle class and pass it to the outFcn. However, if you are fine with the use of global, then it shouldn't be an issue.
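As a rough sketch of the handle-class alternative mentioned above: the class name TrainingLog, its States property, and its record method are all hypothetical names chosen for illustration, not part of any toolbox. Because the object is a handle, the copy captured by the output function and the one in your workspace refer to the same data.

```matlab
% TrainingLog.m (hypothetical class; save it in its own file on the MATLAB path)
classdef TrainingLog < handle
    properties
        States = {};   % one info struct per call of the output function
    end
    methods
        function stop = record(obj, info)
            % Append the info struct received from trainNetwork.
            obj.States{end+1} = info;
            % Returning false lets training continue.
            stop = false;
        end
    end
end
```

To use it, create the object with log = TrainingLog; and pass 'OutputFcn',@(info) log.record(info) to trainingOptions. After training, log.States holds the collected info structs without any global declarations.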
4 Comments
Grant Anderson
11 May 2020
Ameer Hamza
11 May 2020
The information captured by training_state cannot be used to rebuild the network. It only gives info like training accuracy, validation accuracy, etc. I don't think MATLAB provides a way to capture the entire network state during training.
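To illustrate what can be recovered from training_state, here is a small sketch that keeps only the entries recorded on validation iterations. The struct array below is a hand-made stand-in with invented numbers; the field names are assumed to match the info struct passed to the output function, and fields that are not computed on a given call are assumed to be empty.

```matlab
% Hypothetical stand-in for the training_state array collected by outFcn.
% The values are made up for illustration only.
training_state = struct( ...
    'Iteration',          {0, 30, 45, 60}, ...
    'State',              {'start', 'iteration', 'iteration', 'iteration'}, ...
    'TrainingAccuracy',   {[], 71.9, 80.5, 85.2}, ...
    'ValidationAccuracy', {[], 70.3, [], 84.1});

% Keep only the entries where a validation metric was actually recorded.
isValidation  = ~cellfun(@isempty, {training_state.ValidationAccuracy});
validationLog = training_state(isValidation);

% Pull the per-validation values out as plain numeric vectors.
validationIterations = [validationLog.Iteration];
validationAccuracy   = [validationLog.ValidationAccuracy];
```

This gives one record per validation pass, but only metrics; it does not contain the learnable parameters.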
Also, the global variable only needs to be declared in the base workspace and in the workspace of the function that uses it, i.e., outFcn. It does not need to be declared anywhere else. Can you show the general structure of your code and where you are trying to declare the global variables?
Grant Anderson
11 May 2020
Edited: Grant Anderson
11 May 2020
Ameer Hamza
12 May 2020
If you want to check the value of training_state in the base workspace after the execution of your function, then you should also run the following line in the command window before calling your function.
global training_state