Poor performance of trainNetwork() function as compared to train()

Ivan Rodionov on 8 February 2025
Edited: Matt J on 9 February 2025
Hello, I am having issues training a neural network with trainNetwork() compared to train(), and I am stumped. I tried to set up an identical network architecture twice: the first built with fitnet() for the train() function, the second built as a layer array for trainNetwork(). The train() version converges rapidly to a good solution, while I am currently unable to get the trainNetwork() version to converge decently. What am I doing wrong?
Code 1: Working with train()
% Select whether to train a new network or fine-tune an existing one
if useExisting == false
    fprintf("Option A: Train New Model\n");
    net = fitnet(hidden_layer_size); % Define a single-hidden-layer fitting network
else
    fprintf("Option B: Load and Fine-tune Existing Model\n");
    net = loadedData.net; % Load the existing neural network
end
% Set the common network training parameters
net.trainFcn = 'trainscg'; % Scaled conjugate gradient
net.trainParam.epochs = 250E3;
net.trainParam.goal = 0;
net.trainParam.max_fail = 6;
% Divide data for training, validation, and testing (80:10:10 split)
net.divideParam.trainRatio = 0.8;
net.divideParam.testRatio = 0.1;
net.divideParam.valRatio = 0.1;
% Train the neural network on the GPU with shuffled input and target
net = train(net, input, target, 'useGPU', 'yes');
Code 2: Not working
options = trainingOptions('rmsprop', ...
    'MaxEpochs', 250E3, ...                 % Very high epoch cap, matching the train() setup
    'MiniBatchSize', 32, ...                % Small mini-batches
    'InitialLearnRate', 1E-3, ...           % Lower LR since RMSProp adapts per-parameter
    'SquaredGradientDecayFactor', 0.85, ... % Default is 0.9; lowered for faster adaptation
    'Shuffle', 'every-epoch', ...           % Reshuffle since the data has overlap
    'ValidationData', {input(:, 1:round(0.2 * end)), target(:, 1:round(0.2 * end))}, ... % Note: this 20% is not held out from the training set
    'ValidationFrequency', 50, ...          % Check validation every 50 iterations
    'Verbose', true, ...
    'Plots', 'training-progress', ...
    'ValidationPatience', 12, ...           % More patience for slow convergence
    'ExecutionEnvironment', 'gpu');         % Use GPU for speed
% Layers adjusted for 1D data
layers = [
    sequenceInputLayer(chunk_size, 'Name', 'input')       % Adjusted for 1D data
    fullyConnectedLayer(hidden_layer_size, 'Name', 'fc1')
    tanhLayer
    fullyConnectedLayer(chunk_size, 'Name', 'output')     % Adjust output size if needed
    regressionLayer('Name', 'regression')];
% Select whether to train a new network or fine-tune an existing one
if useExisting
    fprintf("Option B: Load and Fine-tune Existing Model\n");
    net = trainNetwork(input, target, net.Layers, options); % assumes 'net' holds a previously trained network
else
    fprintf("Option A: Train New Model\n");
    net = trainNetwork(input, target, layers, options);
end
I have tried playing with the training parameters in trainNetwork(), and this is the best configuration I have found. Unfortunately, the performance is dismal compared to train().
3 Comments
Ivan Rodionov on 8 February 2025
Edited: Ivan Rodionov on 8 February 2025
@Matt J Hello Matt, and thank you for your reply. In some sense you are right, but if I am understanding the code correctly, that should not matter, because both scripts are fundamentally doing the same thing: a slightly wide single-hidden-layer network with an identical tanh activation. Whether it is built via fitnet() or via a layer array, it should not converge well with one and fail entirely with the other, unless something is broken or I am misunderstanding what is going on behind the scenes.
EDIT:
If you are curious, I can gladly provide the datasets; it is a distorted signal.
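One concrete way to test whether the two setups really are identical is to inspect the fitnet() object's defaults. An observation added here, not from the thread: fitnet() applies removeconstantrows/mapminmax preprocessing to both inputs and targets by default, while sequenceInputLayer normalizes nothing unless asked, so the two pipelines may not see the same data. A minimal sketch, assuming input, target, chunk_size, and hidden_layer_size from the code above:
% Sketch: inspect the fitnet() defaults that a bare layer array does not replicate
net = fitnet(hidden_layer_size);
net = configure(net, input, target); % bind input/output sizes to the data

disp(net.layers{1}.transferFcn)      % 'tansig' -- equivalent to tanhLayer
disp(net.inputs{1}.processFcns)      % {'removeconstantrows','mapminmax'}
disp(net.outputs{2}.processFcns)     % the same preprocessing is applied to the targets

% sequenceInputLayer does no such scaling by default; one option is to
% request a comparable input scaling explicitly:
inLayer = sequenceInputLayer(chunk_size, 'Name', 'input', ...
    'Normalization', 'rescale-symmetric'); % rescales inputs to [-1, 1], like mapminmax
Note that trainNetwork() never rescales the regression targets, so if the targets span a large range the two training curves are not directly comparable either.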
Matt J on 9 February 2025
Edited: Matt J on 9 February 2025
But the algorithm used by trainscg() is different, and it has fewer tuning parameters than RMSProp. We don't know how performance might improve if you changed InitialLearnRate, MiniBatchSize, and the other RMSProp parameters. You might try Adam instead of RMSProp; I've heard it is more robust.
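A minimal sketch of swapping in Adam while keeping the rest of the setup from the question; the MaxEpochs and MiniBatchSize values here are illustrative starting points, not tuned values from the thread:
% Illustrative Adam configuration; all name-value pairs are standard
% trainingOptions() parameters, values are starting points to sweep
options = trainingOptions('adam', ...
    'MaxEpochs', 5E3, ...                    % far below the 250E3 above; raise if needed
    'MiniBatchSize', 128, ...                % worth sweeping (32, 128, 512, ...)
    'InitialLearnRate', 1E-3, ...
    'GradientDecayFactor', 0.9, ...          % Adam first-moment decay (default 0.9)
    'SquaredGradientDecayFactor', 0.999, ... % Adam second-moment decay (default 0.999)
    'Shuffle', 'every-epoch', ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'gpu');
net = trainNetwork(input, target, layers, options);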


Answers (1)

Matt J on 9 February 2025
Edited: Matt J on 9 February 2025
"I have tried playing with the training parameters in trainNetwork(), and this is the best configuration I have found."
You can try using the Experiment Manager to explore the hyperparameter space more systematically, for example with a setup function along the lines of the sketch below.
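A minimal sketch of an Experiment Manager setup function, assuming the built-in training-experiment workflow; the hyperparameter names (params.Solver, params.InitialLearnRate, params.MiniBatchSize, params.HiddenSize) and the loadMyData() helper are hypothetical placeholders that would have to match the experiment's hyperparameter table and the actual data loading:
function [XTrain, TTrain, layers, options] = Experiment_setup(params)
% Hypothetical setup function for a built-in Experiment Manager training
% experiment; the field names of 'params' must match the hyperparameter
% table defined in the app.
[XTrain, TTrain] = loadMyData(); % hypothetical data loader

layers = [
    sequenceInputLayer(size(XTrain, 1), 'Name', 'input')
    fullyConnectedLayer(params.HiddenSize, 'Name', 'fc1')
    tanhLayer
    fullyConnectedLayer(size(TTrain, 1), 'Name', 'output')
    regressionLayer('Name', 'regression')];

options = trainingOptions(params.Solver, ...
    'InitialLearnRate', params.InitialLearnRate, ...
    'MiniBatchSize', params.MiniBatchSize, ...
    'MaxEpochs', 1000, ...
    'Shuffle', 'every-epoch', ...
    'ExecutionEnvironment', 'gpu');
end
Experiment Manager then sweeps the table of hyperparameter values and trains one network per combination, which is far less error-prone than hand-editing one trainingOptions() call at a time.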

Release: R2023b