How to use the vanilla SGD solver in training options?

Mira mosad on 20 Dec 2022
Answered: Meet on 12 Sep 2024
When I use vanilla SGD instead of the Adam solver, the code fails with the error: invalid solver name.
How can I use vanilla SGD instead of the Adam solver?
This is my code for the training options part:
options = trainingOptions('sgdm', ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');

Answers (1)

Meet on 12 Sep 2024
Hi Mira,
The "trainingOptions" function does not provide vanilla SGD as a built-in solver, so there is no valid solver name for it. However, you can define a custom SGD update step and use it in a custom training loop.
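A possibly simpler route, sketched here under stated assumptions: the trainingOptions documentation describes the 'sgdm' update as theta(l+1) = theta(l) - alpha*grad E(theta(l)) + gamma*(theta(l) - theta(l-1)), where gamma is the 'Momentum' value (default 0.9). With 'Momentum' set to 0 the momentum term vanishes and only the plain SGD step remains. The sketch also assumes you want to switch off the default L2 regularization (0.0001), which would otherwise add a weight-decay term to the gradient; the remaining values are copied from the question.

```matlab
% 'sgdm' with zero momentum reduces to plain (vanilla) mini-batch SGD.
% Assumption: L2Regularization is set to 0 so that the update is exactly
% theta <- theta - learnRate * gradient (no weight decay).
options = trainingOptions('sgdm', ...
    'Momentum',0, ...
    'L2Regularization',0, ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');
```

The resulting options object is used exactly like the original one, so no custom training loop is needed for this case.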
Below is the code for defining a custom SGD solver and training loop:
Custom SGD Function:
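% Plain SGD step: move each parameter against its gradient, scaled by learnRate.
% In a script file, local functions like this must be placed at the end of the file.
function parameters = sgdStep(parameters,gradients,learnRate)
    parameters = parameters - learnRate .* gradients;
end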
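To see what this rule does to individual parameter arrays, here is a small base-MATLAB sketch (no Deep Learning Toolbox). dlupdate evaluates the update function once per learnable parameter with its matching gradient; this snippet only mimics that with cellfun on ordinary arrays. The parameter and gradient values are hypothetical, and sgdStep is rewritten as an anonymous function so the snippet runs as a plain script.

```matlab
% Same rule as the sgdStep function above, as an anonymous function.
sgdStep = @(parameters,gradients,learnRate) parameters - learnRate.*gradients;
learnRate = 0.1;

% Two hypothetical parameter arrays (e.g. a weight matrix and a bias)
% with their gradients, stored in cell arrays.
params = {[1 2; 3 4], [0.5; -0.5]};
grads  = {[1 0; 0 1], [2; -2]};

% Apply the SGD step to each parameter/gradient pair independently.
updated = cellfun(@(p,g) sgdStep(p,g,learnRate), params, grads, ...
    'UniformOutput', false);
% updated{1} is [0.9 2; 3 3.9], updated{2} is [0.3; -0.3]
```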
Custom Training Loop:
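% Assumes these already exist (defined earlier in your script):
%   net           - dlnetwork to train
%   mbq           - minibatchqueue returning [X,T] mini-batches
%   monitor       - trainingProgressMonitor with Metrics="Loss" and Info="Epoch"
%   numEpochs     - number of epochs
%   numIterations - numEpochs * ceil(numObservations/miniBatchSize)
%   learnRate     - constant SGD learning rate
%   modelLoss     - your function returning [loss,gradients,state]
epoch = 0;
iteration = 0;
% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    % Shuffle data.
    shuffle(mbq);
    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;
        % Read mini-batch of data.
        [X,T] = next(mbq);
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelLoss function and update the network state.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
        net.State = state;
        % Update the network parameters using plain SGD (no momentum).
        updateFcn = @(parameters,gradients) sgdStep(parameters,gradients,learnRate);
        net = dlupdate(updateFcn,net,gradients);
        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch);
        monitor.Progress = 100 * iteration/numIterations;
    end
end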
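The loop above relies on Deep Learning Toolbox objects (dlnetwork, minibatchqueue, trainingProgressMonitor). As a minimal sketch of the same structure using only base MATLAB, the example below runs mini-batch vanilla SGD on a synthetic linear-regression problem: shuffle once per epoch, step through mini-batches, compute the gradient of a mean-squared-error loss, and apply the same sgdStep rule. The data, sizes, and learning rate are all hypothetical.

```matlab
% Mini-batch vanilla SGD on synthetic, noiseless linear-regression data.
rng(0);                                   % reproducible data and shuffling
numObservations = 1000;
X = randn(numObservations, 3);            % one row per observation
wTrue = [1; -2; 0.5];                     % hypothetical ground-truth weights
T = X*wTrue;                              % targets

numEpochs = 20;
miniBatchSize = 50;
learnRate = 0.1;
sgdStep = @(parameters,gradients,learnRate) parameters - learnRate.*gradients;

w = zeros(3,1);                           % parameters to learn
numIterations = numEpochs*ceil(numObservations/miniBatchSize);
iteration = 0;
for epoch = 1:numEpochs
    idx = randperm(numObservations);      % shuffle data once per epoch
    for first = 1:miniBatchSize:numObservations
        iteration = iteration + 1;
        batch = idx(first:min(first+miniBatchSize-1, numObservations));
        Xb = X(batch,:);
        Tb = T(batch);
        % Loss 0.5*mean((Xb*w - Tb).^2) and its gradient with respect to w.
        residual = Xb*w - Tb;
        loss = 0.5*mean(residual.^2);
        gradients = Xb'*residual/numel(batch);
        % Plain SGD update, no momentum.
        w = sgdStep(w, gradients, learnRate);
    end
end
```

Because the targets are noiseless, w should end up very close to wTrue after the 400 iterations; with real, noisy data you would instead watch the loss and typically tune learnRate.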
You can refer to the resource below for more information:
