Update BatchNorm Layer State in Siamese network with custom loop for triplet and contrastive loss

Hi everyone, I'm trying to implement a Siamese network for face verification. As the subnetwork I'm using a ResNet-18 pretrained on my dataset, and I'm trying to implement the triplet loss and the contrastive loss. The main problem is the batch normalization layers in my subnetwork, whose state needs to be updated during training using
dlnet.State = state;
However, the MathWorks tutorials I found only show this update for a cross-entropy loss, where a single dlarray is passed to forward, which returns the state:
function [loss,gradients,state] = modelLoss(net,X,T)
    [Y,state] = forward(net,X);
    ....
end
At the moment this is my training loop for the contrastive loss; there is a similar one for the triplet loss that takes three images at a time:
for iteration = 1:numIterations
    [X1,X2,pairLabels] = GetSiameseBatch(IMGS, miniBatchSize);
    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch) for image data
    dlX1 = dlarray(single(X1),'SSCB');
    dlX2 = dlarray(single(X2),'SSCB');
    % clear X1 X2
    % Move the pairs to GPU memory if a GPU is available or requested
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        dlX1 = gpuArray(dlX1);
        dlX2 = gpuArray(dlX2);
    end
    % Evaluate the loss, the gradients, and the network state using
    % dlfeval and the modelLoss function
    [loss,gradientsSubnet,state] = dlfeval(@modelLoss,dlnet,dlX1,dlX2,pairLabels);
    dlnet.State = state;
    % Update the Siamese subnetwork parameters. Goal: train the last fc
    % layer to output a 128-dimensional feature vector
    [dlnet,trailingAvgSubnet,trailingAvgSqSubnet] = ...
        adamupdate(dlnet,gradientsSubnet, ...
        trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);
    D = duration(0,0,toc(start),Format="hh:mm:ss");
    lossValue = double(gather(extractdata(loss)));
    % lossValue = double(loss);
    addpoints(lineLoss,iteration,lossValue);
    title("Elapsed: " + string(D))
    drawnow
end
And the model loss function is:
function [loss,gradientsSubnet,state] = modelLoss(net,X1,X2,pairLabels)
    % Pass the image pair through the network.
    [F1,F2,state] = ForwardSiamese(net,X1,X2);
    % Calculate the contrastive loss.
    margin = 1;
    loss = ContrastiveLoss(F1,F2,pairLabels, margin);
    % Calculate gradients of the loss with respect to the network learnable
    % parameters.
    gradientsSubnet = dlgradient(loss,net.Learnables);
end
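ContrastiveLoss itself is not shown above. Below is a minimal sketch of the usual contrastive loss with the same signature, assuming pairLabels is 1 for a matching pair and 0 otherwise, and that the features are C-by-B arrays (one column per image, as a network ending in a fully connected layer returns them). TripletLoss is a hypothetical companion for the triplet loop, not code from the question. Both use only elementwise operations, sum and max, so they should work on plain numeric arrays as well as on dlarray inputs inside dlfeval.

```matlab
function loss = ContrastiveLoss(F1,F2,pairLabels,margin)
% Contrastive loss for a mini-batch of feature pairs (sketch).
% F1, F2     : C-by-B features; column i of F1 is paired with column i of F2
% pairLabels : 1-by-B, assumed 1 = same identity, 0 = different identity
% margin     : minimum distance enforced between dissimilar pairs

% Small constant keeps sqrt differentiable when F1 and F2 coincide
delta = 1e-6;
d = sqrt(sum((F1 - F2).^2,1) + delta);   % Euclidean distance per pair

% Similar pairs are pulled together; dissimilar pairs are pushed past margin
losses = pairLabels.*d.^2 + (1 - pairLabels).*max(margin - d,0).^2;
loss = sum(losses,"all")/numel(pairLabels);   % mean over the mini-batch
end

function loss = TripletLoss(FA,FP,FN,margin)
% Triplet loss sketch (hypothetical helper for the triplet loop).
% FA, FP, FN : C-by-B anchor, positive and negative features
dAP = sum((FA - FP).^2,1);   % squared anchor-positive distance per triplet
dAN = sum((FA - FN).^2,1);   % squared anchor-negative distance per triplet
losses = max(dAP - dAN + margin,0);
loss = sum(losses,"all")/size(FA,2);   % mean over the mini-batch
end
```

Note that in this sketch the triplet margin applies to squared distances, while the contrastive margin applies to plain distances.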
In the ForwardSiamese function, however, I run forward on two dlarrays, X1 and X2, which contain the mini-batch of image pairs (X1 holds 32 images and so does X2; the first image in X1 is paired with the first image in X2, and so on), and then compute the loss. So where should the state used to update the batch norm layers come from?
function [Y1,Y2,state] = ForwardSiamese(dlnet,dlX1,dlX2)
    % Pass the first image through the subnetwork and keep its state
    [Y1,state] = forward(dlnet,dlX1);
    Y1 = sigmoid(Y1);
    % Pass the second image through the twin subnetwork
    Y2 = forward(dlnet,dlX2);
    Y2 = sigmoid(Y2);
end
If I also compute [Y2,state], I get two states. Which one should be used to update the batch norm TrainedMean and TrainedVariance?

Accepted Answer

Joss Knight on 6 Nov 2022 (edited by Joss Knight on 6 Nov 2022)
Interesting question! The purpose of batch norm state is to collect statistics about typical inputs. In a normal Siamese workflow, both X1 and X2 are valid inputs, so you ought to be able to update the state with either result.
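For reference, assuming the update rule given in the batchNormalizationLayer documentation, each training-mode forward call folds the per-channel batch mean $\hat{\mu}$ and variance $\hat{\sigma}^2$ into exponential moving averages $\mu^{*}$ and $\sigma^{2*}$ (TrainedMean and TrainedVariance), with decay factors $\lambda_{\mu}$ = MeanDecay and $\lambda_{\sigma^2}$ = VarianceDecay (0.1 by default). The last two lines compare the two passes in the question, with batch statistics $\hat{\mu}_1, \hat{\sigma}^2_1$ for X1 and $\hat{\mu}_2, \hat{\sigma}^2_2$ for X2, against one pass over both, assuming equal batch sizes and population (biased) batch variances:

```latex
\begin{aligned}
\mu^{*} &\leftarrow \lambda_{\mu}\,\hat{\mu} + (1-\lambda_{\mu})\,\mu^{*} \\
\sigma^{2*} &\leftarrow \lambda_{\sigma^2}\,\hat{\sigma}^{2} + (1-\lambda_{\sigma^2})\,\sigma^{2*} \\
\tfrac{1}{2}\bigl[\lambda_{\mu}\hat{\mu}_1 + (1-\lambda_{\mu})\mu^{*}\bigr]
  + \tfrac{1}{2}\bigl[\lambda_{\mu}\hat{\mu}_2 + (1-\lambda_{\mu})\mu^{*}\bigr]
  &= \lambda_{\mu}\,\tfrac{\hat{\mu}_1+\hat{\mu}_2}{2} + (1-\lambda_{\mu})\,\mu^{*} \\
\hat{\sigma}^{2}_{12} &= \tfrac{\hat{\sigma}^{2}_1+\hat{\sigma}^{2}_2}{2}
  + \Bigl(\tfrac{\hat{\mu}_1-\hat{\mu}_2}{2}\Bigr)^{2}
\end{aligned}
```

Since $(\hat{\mu}_1+\hat{\mu}_2)/2$ is exactly the batch mean of the concatenated batch, averaging the two states reproduces the moving mean of a single pass over both halves. The variance $\hat{\sigma}^{2}_{12}$ of the concatenated batch carries an extra term for the gap between the two batch means, so averaged and concatenated moving variances differ slightly when X1 and X2 have different statistics.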
You could aggregate the state from both passes, or even do an additional pass over both inputs together to compute the aggregated state, although that extra pass comes with a performance cost. So, for example:
[~,dlnet.State] = forward(dlnet, cat(4,X1,X2));
You can do this after the call to dlfeval.
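Here is a self-contained sketch of both options, on a hypothetical toy network (one convolution plus one batchNormalizationLayer, standing in for the ResNet-18) with random 'SSCB' batches in place of the real pairs. The names dlnet, dlX1 and dlX2 follow the question's loop; the toy layers, the sizes, and stateCat/stateAvg are illustrative assumptions. It requires Deep Learning Toolbox.

```matlab
% Hypothetical toy subnetwork standing in for the ResNet-18 backbone:
% one convolution followed by a batch normalization layer.
layers = [
    imageInputLayer([16 16 3],Normalization="none")
    convolution2dLayer(3,8,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(4)];
dlnet = dlnetwork(layers);

% Random stand-ins for the two halves of a pair mini-batch
miniBatchSize = 8;
dlX1 = dlarray(rand(16,16,3,miniBatchSize,"single"),"SSCB");
dlX2 = dlarray(rand(16,16,3,miniBatchSize,"single"),"SSCB");

% Option 1 (the answer's suggestion): one extra forward pass over both
% halves concatenated along the batch dimension (dimension 4 for 'SSCB')
[~,stateCat] = forward(dlnet,cat(4,dlX1,dlX2));

% Option 2: keep the state returned by each twin pass and average the
% values entry by entry (TrainedMean and TrainedVariance of each BN layer)
[~,state1] = forward(dlnet,dlX1);
[~,state2] = forward(dlnet,dlX2);
stateAvg = state1;
stateAvg.Value = cellfun(@(a,b) (a + b)/2, state1.Value, state2.Value, ...
    UniformOutput=false);

% Either table can be written back to the network before the next iteration
dlnet.State = stateCat;   % or: dlnet.State = stateAvg;
```

In the real training loop, Option 1 goes right after the dlfeval call and replaces dlnet.State = state. Option 2 would instead live inside ForwardSiamese, so that modelLoss returns the averaged state without an extra pass. For the triplet loop, the same concatenation extends to three batches, e.g. cat(4,dlXA,dlXP,dlXN), where those three names are hypothetical anchor/positive/negative batches.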
4 Comments
Filippo Vascellari on 14 Nov 2022
Great news about the bug; maybe I could speed up the process.
I need the state update because, when I use the classification loss, I have to train the whole network, not only the layers after the pooling layer of the ResNet-18 backbone.
