trainnet: difference between TrainingLoss and manually computed MSE loss

I train a dlnetwork using trainnet with a custom OutputFcn that loads the network from a checkpoint at a given frequency.
This is how the training options are defined:
options = trainingOptions(algo, ...
    'MaxEpochs', epochs, ...
    'MiniBatchSize', 1024, ...
    'InitialLearnRate', learnrate, ...
    'Verbose', false, ...
    'CheckpointPath', checkdir, ...
    'CheckpointFrequencyUnit', 'iteration', ...
    'CheckpointFrequency', check_freq, ...
    OutputFcn=@(info)updatePlotAndStopTraining(info,lines, checkdir, check_freq, XTest, YTrain, XTrain));
This is the custom OutputFcn, where I also manually calculate the MSE:
function stop = updatePlotAndStopTraining(info,lines, directory__, checkFreq, XTest, YTrain, XTrain)
    global msee
    iteration = info.Iteration;
    trainingLoss = info.TrainingLoss;
    if (~isempty(iteration)) && (mod(iteration,checkFreq)==0) && (iteration ~= 0)
        d = dir(fullfile(directory__, '*.mat'));
        dates = {d.date};
        files = {d.name};
        [~, idx] = sort(datenum(dates));
        latest_file_name = files{idx(end)};
        checknet = load(fullfile(directory__, latest_file_name));
        msee = (mse(predict(checknet.net, XTrain.'), YTrain));
    end
    if iteration<checkFreq
        if isvalid(lines.distanceToBase)
            addpoints(lines.trainingLossLine,iteration,1.0)
            addpoints(lines.mse,iteration,0.0)
        end
    elseif ~isempty(trainingLoss)
        if isvalid(lines.distanceToBase)
            addpoints(lines.trainingLossLine,iteration,trainingLoss)
            addpoints(lines.mse,iteration,msee)
        end
    end
    stop = false;
end
This is the training call:
[finalnet,info] = trainnet(XTrain.', YTrain.', resetNet,'mse', options);
I would expect the MSE and the training loss to be very close, differing only because TrainingLoss is normalized, but they move in opposite directions. My manually computed MSE suggests the model is not converging, while TrainingLoss shows some convergence...

Answers (1)

Hi,
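You computed the MSE with a checkpointed network, which may lag behind the current training state. In addition, predict(checknet.net, XTrain.') runs on the entire training set, whereas TrainingLoss is computed per mini-batch and normalized. There is also an orientation mismatch: training uses YTrain.' as the targets, but the manual MSE compares the predictions against YTrain. Together, these timing and orientation differences can produce misleading trends.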
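To illustrate the orientation point, here is a minimal sketch using plain arrays and `mean` rather than the toolbox `mse` function. The variables `t` and `p` are hypothetical stand-ins for YTrain and the network output, and whether your actual arrays have these shapes is an assumption you would need to check with `size`. When a row vector is subtracted from a column vector, MATLAB's implicit expansion silently produces a full matrix of pairwise differences instead of an error:

```matlab
% Hypothetical stand-in data: 5 observations, one response each.
t = [1; 2; 3; 4; 5];          % targets as a 5-by-1 column (like YTrain)
p = [1.1 2.1 2.9 4.2 4.8];    % predictions as a 1-by-5 row (like the output for XTrain.')

% Mismatched orientation: implicit expansion yields a 5-by-5 matrix
% of all pairwise differences, not the 5 per-observation errors.
d_bad = p - t;
mse_bad = mean(d_bad.^2, 'all');

% Matching orientation: element-wise difference, 1-by-5.
d_ok = p - t.';
mse_ok = mean(d_ok.^2, 'all');

fprintf('size(d_bad) = %s, mse_bad = %.4f\n', mat2str(size(d_bad)), mse_bad);
fprintf('size(d_ok)  = %s, mse_ok  = %.4f\n', mat2str(size(d_ok)),  mse_ok);
```

If the mismatched value is much larger than the matched one and barely changes as the predictions improve, orientation is a likely contributor to the diverging curves.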
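These approaches can help resolve it:
  1. Use the current network instead of a loaded checkpoint, if the info structure in your release exposes one (the code below assumes a TrainedNetwork field; check with isfield(info,"TrainedNetwork")). Otherwise, keep the checkpoint but confirm that the file you load was written at the current iteration.
  2. Compute the MSE on the same mini-batch, or on a validation set, for consistency, and pass the targets in the same orientation as in the trainnet call:
preds = predict(info.TrainedNetwork, XTrain.');
msee = mse(preds, YTrain.');
This aligns your metric with training progress.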
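As a rough illustration of why a per-mini-batch value and a whole-set value need not agree at a given iteration, the sketch below uses made-up residuals `r` (prediction minus target) for a single fixed set of weights and an assumed batch size of 4. It uses plain `mean`, not the loss definition that trainnet applies internally, so only the qualitative behavior carries over:

```matlab
% Hypothetical residuals (prediction minus target) for 8 observations
% under one fixed set of weights.
r = [0.5 -1.0 0.2 0.1 2.0 -0.3 0.4 -0.6];

miniBatchSize = 4;                       % assumed batch size for this sketch
numBatches = numel(r) / miniBatchSize;   % 2 batches here

fullLoss = mean(r.^2);                   % loss over the whole set

batchLoss = zeros(1, numBatches);
for k = 1:numBatches
    idx = (k-1)*miniBatchSize + (1:miniBatchSize);
    batchLoss(k) = mean(r(idx).^2);      % loss on one mini-batch
end

fprintf('full-set loss:     %.4f\n', fullLoss);
fprintf('mini-batch losses: %s\n', mat2str(batchLoss, 4));
```

With equal-sized batches and fixed weights, the individual mini-batch losses scatter around the full-set loss and their average equals it. During training the weights also change between mini-batches, so the two curves can differ iteration by iteration even when both are computed correctly; they should still follow the same overall trend.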
Hope it helps!

Release

R2025b

Asked:

SYt
17 November 2025

Answered:

20 November 2025
