How to use the vanilla SGD solver in trainingOptions?

Mira mosad, 20 December 2022
Answered by: Meet, 12 September 2024
When I use vanilla SGD instead of the Adam solver, the code fails with the error "invalid solver name".
How can I use vanilla SGD instead of the Adam solver?
This is the training-options part of my code:
options = trainingOptions('sgdm', ...
'MaxEpochs',20,...
'InitialLearnRate',1e-4, ...
'Verbose',false, ...
'Plots','training-progress');

Answers (1)

Meet, 12 September 2024
Hi Mira,
The trainingOptions function does not accept a separate solver name for vanilla SGD, so it is not available there as a pre-built solver. However, you can define a custom SGD update step and use it in a custom training loop.
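For comparison, the built-in 'sgdm' solver contains plain SGD as a special case: its update adds a momentum term, scaled by the 'Momentum' option (default 0.9), to the plain gradient step, so setting 'Momentum' to 0 leaves only the gradient step. trainingOptions also applies L2 regularization by default ('L2Regularization' = 1e-4). A minimal sketch that turns both off, assuming the remaining options stay as in the question:

```matlab
% 'sgdm' with zero momentum: each step is -learnRate * gradient.
options = trainingOptions('sgdm', ...
    'Momentum',0, ...            % remove the momentum term -> plain SGD step
    'L2Regularization',0, ...    % optional: also drop the default 1e-4 L2 penalty
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');
```

This keeps the trainNetwork/trainingOptions workflow; the custom loop is still the route if you need an update rule that none of the built-in solvers can express.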
Below is the code for defining a custom SGD solver and training loop:
Custom SGD Function:
function parameters = sgdStep(parameters,gradients,learnRate)
% Plain SGD update: move each parameter a step of size learnRate against its gradient.
parameters = parameters - learnRate .* gradients;
end
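Because sgdStep only uses element-wise arithmetic, it works on ordinary numeric arrays as well as on the dlarray values that dlupdate passes in. A quick check with hypothetical values, where each new value is parameter - learnRate * gradient (for example 1 - 0.1 * 0.5 = 0.95):

```matlab
% One plain-SGD step on an ordinary numeric array (hypothetical values).
parameters = [1 2 3];
gradients  = [0.5 -1 2];
learnRate  = 0.1;
parameters = sgdStep(parameters,gradients,learnRate)   % result: [0.95 2.1 2.8]

function parameters = sgdStep(parameters,gradients,learnRate)
% Move each parameter a step of size learnRate against its gradient.
parameters = parameters - learnRate .* gradients;
end
```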
Custom Training Loop:
% This loop assumes the following are already defined:
%   net                              - a dlnetwork
%   modelLoss                        - a function returning [loss,gradients,state]
%   numEpochs, numIterations, learnRate - training settings
%   mbq                              - a minibatchqueue over the training data
%   monitor                          - a trainingProgressMonitor with a "Loss" metric and "Epoch" info
epoch = 0;
iteration = 0;
% SGD update applied to every learnable parameter by dlupdate.
updateFcn = @(parameters,gradients) sgdStep(parameters,gradients,learnRate);
% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    % Shuffle data.
    shuffle(mbq);
    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;
        % Read mini-batch of data.
        [X,T] = next(mbq);
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelLoss function, then update the network state.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
        net.State = state;
        % Update the network parameters using plain SGD.
        net = dlupdate(updateFcn,net,gradients);
        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch);
        monitor.Progress = 100 * iteration/numIterations;
    end
end
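The loop calls a modelLoss function that is not shown. A minimal sketch, assuming a classification network ending in a softmax layer and one-hot targets, follows the usual forward/crossentropy/dlgradient pattern. The tiny network, the random data, and learnRate = 0.01 are hypothetical, chosen only so one iteration of the loop body can run and be checked without a minibatchqueue or progress monitor:

```matlab
rng(0);  % reproducible random data and initial weights

% Hypothetical tiny classifier: 4 input features -> 3 classes.
layers = [
    featureInputLayer(4)
    fullyConnectedLayer(3)
    softmaxLayer];
net = dlnetwork(layers);

% Hypothetical mini-batch of 8 observations, format "CB" (channel x batch).
X = dlarray(rand(4,8,'single'),'CB');
classIdx = randi(3,1,8);
T = zeros(3,8,'single');
T(sub2ind(size(T),classIdx,1:8)) = 1;   % one-hot targets
T = dlarray(T,'CB');

learnRate = 0.01;
updateFcn = @(parameters,gradients) sgdStep(parameters,gradients,learnRate);

% One iteration of the loop body: loss and gradients, state update, SGD step.
netBefore = net;
[loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
net.State = state;
net = dlupdate(updateFcn,net,gradients);

function [loss,gradients,state] = modelLoss(net,X,T)
% Forward the mini-batch, compute cross-entropy, differentiate w.r.t. learnables.
[Y,state] = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);
end

function parameters = sgdStep(parameters,gradients,learnRate)
% Plain SGD: step each parameter against its gradient.
parameters = parameters - learnRate .* gradients;
end
```

In the full loop these same lines run once per mini-batch; wrapping sgdStep in updateFcn is what lets dlupdate apply it to every entry of net.Learnables.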
You can refer to the resource below for more information:
