Why are the gradients not backpropagating into the encoder in this custom loop?

I am building a convolutional autoencoder using a custom training loop. When I try to reconstruct the images, the network's output degenerates to predicting the same incorrect value for every input. However, training the same autoencoder as a single stack with the trainnet function works fine, which suggests to me that the gradient updates are not making it across the bottleneck layer in the custom training loop. Unfortunately, I need the custom training loop for a different task, and I am not allowed to use TensorFlow or PyTorch.
What is the syntax to ensure that the encoder is able to update based on the decoder's reconstruction performance?
%% Functional 'trainnet' loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Full autoencoder: encoder layers followed by decoder layers
aeLayers = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name='Output')
];
autoencoder = dlnetwork(aeLayers);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
options = trainingOptions("adam", ...
InitialLearnRate=learnRate,...
MaxEpochs=30, ...
Plots="training-progress", ...
TargetDataFormats="SSCB", ...
InputDataFormats="SSCB", ...
MiniBatchSize=miniBatchSize, ...
OutputNetwork="last-iteration", ...
Shuffle="every-epoch");
autoencoder = trainnet(dlarray(xTrain, 'SSCB'),dlarray(xTrain, 'SSCB'), ...
autoencoder, 'mse', options);
%% Testing
YTest = predict(autoencoder, dlarray(xTest, 'SSCB'));
indices = randperm(size(xTest, 4)); % Shuffle YTest & xTest (randperm is a true permutation; randi would repeat samples)
xTest = xTest(:,:,:,indices); YTest = YTest(:,:,:,indices);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = extractdata(YTest(:,:,:,1:numImages));
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Nonfunctional Custom Training Loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Encoder
layersE = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer];
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
layersD = [
imageInputLayer(projectionSize)
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name='Output')
];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
% Create training minibatchqueue
dsTrain = arrayDatastore(xTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, ...
MiniBatchSize = miniBatchSize, ...
MiniBatchFormat="SSCB", ...
MiniBatchFcn=@preprocessMiniBatch,...
PartialMiniBatch="return");
%Initialize the parameters for the Adam solver.
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
%Calculate the total number of iterations for the training progress monitor
numIterationsPerEpoch = ceil(size(xTrain, 4) / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
epoch = 0;
iteration = 0;
%Initialize the training progress monitor.
monitor = trainingProgressMonitor( ...
Metrics="TrainingLoss", ...
Info=["Epoch", "LearningRate"], ...
XLabel="Iteration");
%% Training
while epoch < numEpochs && ~monitor.Stop
epoch = epoch + 1;
% Shuffle data.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq) && ~monitor.Stop
% Update the iteration counter (no validation is performed in this loop)
iteration = iteration + 1;
% Read mini-batch of data.
X = next(mbq);
% Evaluate loss and gradients.
[loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);
% Update learnable parameters.
[netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, ...
gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
[netD, trailingAvgD, trailingAvgSqD] = adamupdate(netD, ...
gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);
updateInfo(monitor, ...
LearningRate=learnRate, ...
Epoch=string(epoch) + " of " + string(numEpochs));
recordMetrics(monitor,iteration, ...
TrainingLoss=loss);
monitor.Progress = 100*iteration/numIterations;
end
end
%% Testing
ntest = size(xTest, 4);
indices = randperm(ntest);
xTest = xTest(:,:,:,indices); % Shuffle test data BEFORE building the datastore so predictions line up with xTest
dsTest = arrayDatastore(xTest,IterationDimension=4);
numOutputs = 1;
mbqTest = minibatchqueue(dsTest,numOutputs, ...
MiniBatchSize = miniBatchSize, ...
MiniBatchFcn=@preprocessMiniBatch, ...
MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = YTest(:,:,:,1:numImages);
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Functions
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Forward through encoder.
Z = forward(netE,X);
% Forward through decoder.
Xrecon = forward(netD,Z);
% Calculate loss and gradients.
loss = regularizedLoss(Xrecon,X);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
function loss = regularizedLoss(Xrecon,X)
% Image Reconstruction loss.
reconstructionLoss = l2loss(Xrecon, X, 'NormalizationFactor','all-elements');
% Combined loss.
loss = reconstructionLoss;
end
function Xrecon = modelPredictions(netE,netD,mbq)
Xrecon = [];
reset(mbq) % Do not reshuffle: keep predictions in the same order as xTest
% Loop over mini-batches.
while hasdata(mbq)
X = next(mbq);
% Pass through encoder
Z = predict(netE,X);
% Pass through decoder to get reconstructed images
XGenerated = predict(netD,Z);
% Extract and concatenate predictions.
Xrecon = cat(4,Xrecon,extractdata(XGenerated));
end
end
function X = preprocessMiniBatch(Xcell)
% Concatenate.
X = cat(4,Xcell{:});
end
Comments: 2
Matt J on 16 Jul 2024 (edited 16 Jul 2024)
training the autoencoder in a single stack with the trainnet function works fine, indicating that the gradient updates are unable to bridge the bottleneck layer in the custom training loop.
I don't see why a failure of gradient backpropagation is the only possible culprit. However, it should be easy to test. You can implement a second version of your modelLoss that takes the single stacked network as input, then run dlfeval on both versions and check whether they return the same gradients (to within floating-point differences).
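A lighter-weight check in the same spirit is to look at the size of the encoder gradients directly: if backpropagation were not reaching the encoder, every entry of gradientsE would be zero. The helper below, learnableGradientNorms, is a hypothetical name introduced here; it only assumes the gradients come back as the Layer/Parameter/Value table that dlgradient returns when it is given net.Learnables.

```matlab
function gradNorms = learnableGradientNorms(gradients)
% Hypothetical helper: L2 norm of each gradient array in a dlgradient
% table (columns Layer, Parameter, Value), one entry per learnable.
numParams = height(gradients);
gradNorms = zeros(numParams,1);
for i = 1:numParams
    g = extractdata(gradients.Value{i});  % dlarray -> numeric (or gpuArray)
    gradNorms(i) = gather(norm(g(:)));    % gather is a no-op for CPU data
end
end
```

In the training loop it could be called right after dlfeval, e.g. normsE = learnableGradientNorms(gradientsE), and compared with the decoder's norms over the first few iterations. Encoder norms that are clearly nonzero would indicate that the gradients do cross the bottleneck.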
Joseph Conroy on 16 Jul 2024
Hm. The stacked autoencoder trains exactly the same as the separate encoder/decoders when run through the custom training loop. It must be something I have done in the training loop then.

Accepted Answer

Matt J on 16 Jul 2024 (edited 16 Jul 2024)
In terms of what may be different from trainnet, I don't see any regularization in your customized loop. You have a function called regularizedLoss(), but it doesn't seem to evaluate any regularization terms or apply any regularization hyperparameters.
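One way to close this gap is to add the L2 penalty's gradient before calling adamupdate. trainingOptions uses an L2Regularization factor of 1e-4 by default, and convolution and transposed convolution layers apply it to their weights (WeightL2Factor = 1) but not their biases (BiasL2Factor = 0). The sketch below follows the dlupdate approach the MATLAB documentation uses for L2 regularization in custom loops; addL2Regularization is a hypothetical helper name, and whether it reproduces trainnet's update exactly is an assumption.

```matlab
function gradients = addL2Regularization(gradients, learnables, l2Regularization)
% Hypothetical helper: add the gradient of (l2Regularization/2)*sum(w.^2),
% i.e. l2Regularization*w, to the "Weights" gradients only. Biases are left
% untouched, assuming the default BiasL2Factor of 0.
idx = learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, ...
    gradients(idx,:), learnables(idx,:));
end
```

In the loop it would go between dlfeval and the two adamupdate calls, e.g. gradientsE = addL2Regularization(gradientsE,netE.Learnables,1e-4), and likewise for gradientsD with netD.Learnables.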
Aside from that, I wonder where the parameter initialization is happening. Presumably it is in adamupdate(), but since you call adamupdate separately on netE and netD, I am not sure how that might be affecting the initialization as compared to when trainnet is used on the entire end-to-end network.
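On the initialization question: dlnetwork initializes any empty learnable parameters when the network is constructed (by default, Glorot for convolution and transposed convolution weights and zeros for biases), so netE and netD are already initialized before the first adamupdate call, which only builds its moment estimates from the [] placeholders. Because Adam operates element by element, separate adamupdate calls on two parameter sets should match one call on their union at the same iteration. The hypothetical function below checks that on small random dlarray vectors standing in for the encoder and decoder parameters.

```matlab
function maxDiff = compareSplitVsJointAdam()
% Hypothetical check: Adam is element-wise, so updating two parameter sets
% in separate adamupdate calls should match a single call on both together.
rng(0)
pA = dlarray(randn(3,1));  gA = dlarray(randn(3,1));   % "encoder" part
pB = dlarray(randn(2,1));  gB = dlarray(randn(2,1));   % "decoder" part
iteration = 1;
% Separate calls, each starting from empty moment estimates.
pA1 = adamupdate(pA,gA,[],[],iteration);
pB1 = adamupdate(pB,gB,[],[],iteration);
% One joint call on the stacked parameters and gradients.
pJ = adamupdate([pA;pB],[gA;gB],[],[],iteration);
maxDiff = max(abs(extractdata([pA1;pB1] - pJ)));
end
```

A maxDiff of zero (up to rounding) would indicate that splitting the update between netE and netD is not what distinguishes the custom loop from trainnet.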
Comments: 19
Joseph Conroy on 18 Jul 2024
Marvelous! Thank you.
I know you've given me a great deal of your time, but if you can manage it, I would like to hear your thoughts on why adding L2 regularization to the weights induces a solution that fits the training data better, not just one that generalizes better. My understanding is that L2 regularization amounts to placing a zero-mean Gaussian prior on a layer's weights, pulling the solution toward the origin W = 0. I can easily see how this favors more generalizable solutions, but I am not certain why it lets the solver escape the local minimum of guessing the average.
My first guess would be that guessing the average reduces the error signal from the predictions, allowing the -const*w term to dominate in the gradient updates, which subsequently pulls the solution out of its degenerate behavior. This vague notion suggests this behavior is not general, but rather is unique to a dataset wherein the optimal solution lies closer to the weight-space origin than the average guess.
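The mechanism described above can be written out for plain gradient descent with step size η and L2 factor λ (symbols introduced here for illustration). With Adam, as in the loop in the question, the combined gradient is additionally rescaled per element, so this is only the simplified picture:

```latex
\begin{aligned}
\tilde{L}(w) &= L(w) + \frac{\lambda}{2}\lVert w\rVert_2^2,
  \qquad \nabla \tilde{L}(w) = \nabla L(w) + \lambda w \\
-\log p(w) &= \frac{1}{2\sigma^2}\lVert w\rVert_2^2 + \text{const}
  \quad \text{for } p(w) = \mathcal{N}(0, \sigma^2 I) \\
w_{t+1} &= w_t - \eta\bigl(\nabla L(w_t) + \lambda w_t\bigr)
  = (1 - \eta\lambda)\,w_t - \eta\,\nabla L(w_t)
\end{aligned}
```

The second line is why λ plays the role of 1/σ² (up to the scaling of the data term). When the data gradient ∇L(w_t) is close to zero, as on a plateau where the network always predicts the average, the step reduces to w_{t+1} ≈ (1 − ηλ) w_t: the weights shrink geometrically toward the origin, which is the "−const·w" term dominating.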
Matt J on 18 Jul 2024 (edited 18 Jul 2024)
This vague notion suggests this behavior is not general, but rather is unique to a dataset wherein the optimal solution lies closer to the weight-space origin than the average guess.
I don't know how general it is, but deep learning data loss functions do tend to have plateaus and local minima at large values of the weights, because with large weights it is easy for the ReLUs to saturate.

Category: Custom Training Loops

Release: R2023b
