Update BatchNorm Layer State in Siamese network with custom loop for triplet and contrastive loss
Hi everyone, I'm trying to implement a Siamese network for face verification. As the subnetwork I'm using a ResNet-18 pretrained on my dataset, and I'm trying to implement the triplet loss and the contrastive loss. The main problem is the batch normalization layers in my subnetwork, whose state needs to be updated during the training phase using
But in the MathWorks tutorials I only found the state update for the cross-entropy case, where the forward function takes a single dlarray as input and returns the state:
function [loss,gradients,state] = modelLoss(net,X,T)
[Y,state] = forward(net,X);
At the moment this is my training loop for the contrastive loss. There is a similar one for the triplet loss that takes 3 images at a time:
for iteration = 1:numIterations
[X1,X2,pairLabels] = GetSiameseBatch(IMGS, miniBatchSize);
% Convert mini-batch of data to dlarray. Specify the dimension labels
% 'SSCB' (spatial, spatial, channel, batch) for image data
dlX1 = dlarray(single(X1),'SSCB');
dlX2 = dlarray(single(X2),'SSCB');
% clear X1 X2
% I load the pairs into the GPU memory
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
dlX1 = gpuArray(dlX1);
dlX2 = gpuArray(dlX2);
end
% Evaluate the model loss, the gradients and the network state using
% dlfeval and the modelLoss function
[loss,gradientsSubnet,state] = dlfeval(@modelLoss,dlnet,dlX1,dlX2,pairLabels);
dlnet.State = state;
% Update the Siamese subnetwork parameters. Scope: train the last fc
% for 128 dim features vector
[dlnet,trailingAvgSubnet,trailingAvgSqSubnet] = ...
D = duration(0,0,toc(start),Format="hh:mm:ss");
lossValue = double(gather(extractdata(loss)));
% lossValue = double(loss);
title("Elapsed: " + string(D))
And the model loss is
function [loss,gradientsSubnet,state] = modelLoss(net,X1,X2,pairLabels)
% Pass the image pair through the network.
[F1,F2,state] = ForwardSiamese(net,X1,X2);
% Calculate the contrastive loss.
margin = 1;
loss = ContrastiveLoss(F1,F2,pairLabels, margin);
% Calculate gradients of the loss with respect to the network learnable
% parameters.
gradientsSubnet = dlgradient(loss,net.Learnables);
In the ForwardSiamese function I run the forward pass on the two dlarrays X1 and X2, which contain the batch of image pairs (i.e. X1 holds 32 images and X2 holds 32 images; the first image in X1 is paired with the first image in X2, and so on), and then compute the loss. But where should the state used to update the batch norm layers come from?
function [Y1,Y2,state] = ForwardSiamese(dlnet,dlX1,dlX2)
[Y1,state] = forward(dlnet,dlX1);
Y1 = sigmoid(Y1);
% Pass the second image through the twin subnetwork
Y2 = forward(dlnet,dlX2);
Y2 = sigmoid(Y2);
If I also compute [Y2,state] I have two states, but which one should be used to update the batch norm TrainedMean and TrainedVariance?
Joss Knight on 6 Nov 2022
Edited: Joss Knight on 6 Nov 2022
Interesting question! The purpose of batch norm state is to collect statistics about typical inputs. In a normal Siamese workflow, both X1 and X2 are valid inputs, so you ought to be able to update the state with either result.
You could aggregate the state from both, or even do an additional pass with both to compute the aggregated state, although this would come with an extra performance cost. For the additional pass:
[~,dlnet.State] = forward(dlnet, cat(4,X1,X2));
You can do this after the call to dlfeval.
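If you would rather aggregate the two states than do the extra pass, one possible sketch is to average the two State tables element-wise. The helper name averageBatchNormState is hypothetical (it is not a toolbox function), and the sketch assumes that both tables come from the same network, so their rows line up, and that the only state in the network is the batch norm TrainedMean and TrainedVariance. Averaging the two variances is only an approximation of the variance over the combined batch.

```matlab
function state = averageBatchNormState(state1,state2)
% Hypothetical helper, not part of any toolbox.
% Averages two network State tables element-wise.
% Assumes both tables were returned by forward on the SAME network, so
% row k of state1 and row k of state2 describe the same layer parameter.
assert(height(state1) == height(state2), ...
    "State tables must have the same number of rows.");

% Keep the Layer and Parameter columns from the first table and replace
% only the Value column with the average of the two.
state = state1;
state.Value = cellfun(@(a,b) (a + b)./2, ...
    state1.Value, state2.Value, "UniformOutput",false);
end
```

In ForwardSiamese you would then request the state from both passes, [Y1,state1] = forward(dlnet,dlX1) and [Y2,state2] = forward(dlnet,dlX2), and return state = averageBatchNormState(state1,state2), so the existing dlnet.State = state line in the training loop stays unchanged.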