Gradients not recorded for a dlnetwork VAE

SIMON on 22 Aug 2025
Commented: Torsten on 23 Aug 2025
I have created a VAE using dlnetwork objects: an encoder, a decoder, and a classifier. Gradients are not recorded for the loss, so the update via dlgradient fails. Can you take a look and work out the cause?
% === Setup ===
numEpochs = 5;
miniBatchSize = 32;
learnRate = 1e-3;
latentDim = 32;
hiddenUnits = 64;
inputSize = size(XTrain{1}, 1); % Number of features
sequenceLength = size(XTrain{1}, 2); % Time steps
% === Encoder ===
encoderLayers = [
sequenceInputLayer(inputSize, 'Name', 'input')
lstmLayer(64, 'OutputMode', 'last', 'Name', 'lstm_enc')
fullyConnectedLayer(latentDim, 'Name', 'fc_mu')
fullyConnectedLayer(latentDim, 'Name', 'fc_logvar')
];
encoderNet = dlnetwork(layerGraph(encoderLayers));
% === Decoder ===
decoderLayers = [
sequenceInputLayer(latentDim, 'Name', 'latent_input')
fullyConnectedLayer(64, 'Name', 'fc_latent')
lstmLayer(64, 'OutputMode', 'sequence', 'Name', 'lstm_dec')
fullyConnectedLayer(inputSize, 'Name', 'fc_recon')
];
decoderNet = dlnetwork(layerGraph(decoderLayers));
% === Classifier ===
classifierLayers = [
featureInputLayer(latentDim, 'Name', 'class_input')
fullyConnectedLayer(hiddenUnits, 'Name', 'fc_class_hidden')
reluLayer('Name', 'relu_class')
fullyConnectedLayer(numClasses, 'Name', 'fc_class')
softmaxLayer('Name', 'softmax_out')
];
classifierNet = dlnetwork(layerGraph(classifierLayers));
function loss = computeTotalLoss(dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength)
    % All operations inside this function are traced
    [mu, logvar] = encodeLatents(dlX, encoderNet);
    z = sampleLatents(mu, logvar); % Sampling with reparameterization
    loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength);
end
function [mu, logvar] = encodeLatents(dlX, encoderNet)
    mu = forward(encoderNet, dlX, 'Outputs', 'fc_mu');
    logvar = forward(encoderNet, dlX, 'Outputs', 'fc_logvar');
end
function z = sampleLatents(mu, logvar)
    eps = dlarray(randn(size(mu), 'like', mu)); % Traced random noise
    z = mu + exp(0.5 * logvar) .* eps;
end
function loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength)
    zRepeated = repmat(z, 1, 1, sequenceLength);
    dlZSeq = dlarray(zRepeated, 'CBT');
    reconOut = forward(decoderNet, dlZSeq, 'Outputs', 'fc_recon');
    classOut = forward(classifierNet, dlarray(z, 'CB'), 'Outputs', 'softmax_out');
    reconLoss = mse(reconOut, dlX);
    classLoss = crossentropy(classOut, dlY);
    klLoss = -0.5 * sum(1 + logvar - mu.^2 - exp(logvar), 'all');
    loss = reconLoss + classLoss + klLoss;
end
for epoch = 1:numEpochs
    idx = randperm(numel(XTrain));
    totalEpochLoss = 0;
    numBatches = 0;
    for i = 1:miniBatchSize:numel(XTrain)
        batchIdx = idx(i:min(i+miniBatchSize-1, numel(XTrain)));
        XBatch = XTrain(batchIdx);
        YBatch = YTrain(batchIdx, :);
        % === Format Inputs ===
        XCat = cat(3, XBatch{:});
        XCat = permute(XCat, [1, 3, 2]);
        dlX = dlarray(XCat, 'CBT');
        dlY = dlarray(YBatch', 'CB');
        totalLoss = dlfeval(@computeTotalLoss, dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength);
        gradients = dlgradient(totalLoss, encoderNet.Learnables);
Error using dlarray/dlgradient (line 105)
Value to differentiate is not traced. It must be a traced real dlarray scalar. Use dlgradient inside a function called by dlfeval to trace the variables.

Accepted Answer

Torsten on 22 Aug 2025
Moved: Torsten on 22 Aug 2025
Did you read the documentation on how "dlgradient" is to be applied?
You have to call "dlgradient" inside the function "computeTotalLoss", not outside it.
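The rule above can be seen in a minimal sketch that leaves the VAE out entirely. The variables w and x and the function lossAndGradient are hypothetical names chosen for illustration, not part of the code in the question; the example assumes Deep Learning Toolbox is installed.

```matlab
% Minimal sketch: dlgradient must run inside the function that dlfeval evaluates.
w = dlarray([1; 2; 3]);      % hypothetical parameter vector
x = dlarray([0.5; -1; 2]);   % hypothetical input

% Works: the gradient is computed while the trace is still active.
[loss, grad] = dlfeval(@lossAndGradient, w, x);

% The pattern from the question would fail here, because the value
% returned by dlfeval is no longer traced:
%   lossOnly = dlfeval(@(w, x) sum((w .* x).^2, 'all'), w, x);
%   grad = dlgradient(lossOnly, w);   % "Value to differentiate is not traced."

function [loss, grad] = lossAndGradient(w, x)
    % Scalar loss built from operations on the traced inputs.
    loss = sum((w .* x).^2, 'all');
    % dlgradient is called before the function returns to dlfeval.
    grad = dlgradient(loss, w);
end
```

Here the gradient is 2*w.*x.^2, so grad should come out as [0.5; 4; 24] and loss as 40.25. In the question's code, the same change would mean giving computeTotalLoss a second output and calling dlgradient as its last step.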
  2 Comments
SIMON on 23 Aug 2025
OK, thanks, this is news to me. I will do that. Thanks so much.
Catalytic on 23 Aug 2025
@SIMON, you should click accept on Torsten's answer if it solves your problem.


More Answers (1)

Matt J on 22 Aug 2025
Edited: Matt J on 22 Aug 2025
As the error message says, the call to dlgradient must occur within the function (in this case computeTotalLoss) called by dlfeval, but that is not what you've done.
More confusingly, it appears that computeTotalLoss and computeLoss call each other in a recursive loop. That will create problems as well, I imagine. You are not meant to be calling dlfeval inside the loss function. It is meant to be called externally, in your training loop.
  2 Comments
SIMON on 23 Aug 2025
The two functions are not recursive: computeLoss is called by computeTotalLoss, which is evaluated inside dlfeval. However, the gradients are calculated outside of dlfeval, so I may try moving that call inside. Thanks.
Torsten on 23 Aug 2025
You have to put "dlgradient" inside "computeTotalLoss". Then you can call "dlfeval" from somewhere outside "computeTotalLoss" and "computeLoss".
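As a rough sketch of that arrangement, the loop below calls dlfeval from the training loop, while dlgradient sits inside the loss function. The tiny network, the random data, and the name modelLoss are placeholders chosen for illustration and are not the VAE from the question; the Adam update via adamupdate is likewise an assumption about how the gradients would be used.

```matlab
% Sketch of a custom training step: loss and gradients come from one dlfeval call.
% The small classifier and random data are placeholders, not the thread's VAE.
layers = [
    featureInputLayer(4, 'Name', 'in')
    fullyConnectedLayer(3, 'Name', 'fc')
    softmaxLayer('Name', 'sm')
    ];
net = dlnetwork(layerGraph(layers));

dlX = dlarray(rand(4, 8, 'single'), 'CB');             % 4 features, 8 observations
labels = randi(3, 1, 8);                               % random class indices 1..3
dlY = dlarray(single((1:3)' == labels), 'CB');         % one-hot targets, 3-by-8

avgGrad = [];      % Adam state, empty on the first iteration
avgSqGrad = [];
learnRate = 1e-3;

for iteration = 1:5
    % dlfeval is called here, outside the loss function.
    [loss, gradients] = dlfeval(@modelLoss, net, dlX, dlY);
    [net, avgGrad, avgSqGrad] = adamupdate(net, gradients, ...
        avgGrad, avgSqGrad, iteration, learnRate);
    fprintf('Iteration %d, loss %.4f\n', iteration, extractdata(loss));
end

function [loss, gradients] = modelLoss(net, dlX, dlY)
    dlYPred = forward(net, dlX);
    loss = crossentropy(dlYPred, dlY);
    % dlgradient is called inside the function evaluated by dlfeval.
    gradients = dlgradient(loss, net.Learnables);
end
```

With several networks, as in the question, dlgradient accepts more than one variable, for example [gEnc, gDec, gCls] = dlgradient(loss, encoderNet.Learnables, decoderNet.Learnables, classifierNet.Learnables), and each network then gets its own update call with its own Adam state.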
