Custom training loop and parameters for semantic segmentation problem

Manon on 5 July 2024 at 14:03
Answered: praguna manvi on 15 July 2024 at 13:09
Good day
I have been working on a semantic segmentation problem using the Deep Learning Toolbox and the unetLayers function.
However, I now need to control some settings that I believe are not accessible through the toolbox. I would like to:
  • Use a custom loss function: the sum of a weighted focal loss and a Dice loss.
  • Augment the data at each epoch and, if possible, use a custom augmentation method called RLR (Random Local Rotation).
  • Use a learning-rate schedule based on the 'One cycle' policy.
  • Track the training progress in as much detail as possible.
Here is my base code, which worked for training the U-Net:
% Split the data: first 10% of the files for validation, the rest for training
[imds,pxds] = create_ds(folders);
numFiles = numel(imds.Files);
numTest = floor(numFiles/10);
indicesTest = 1:numTest;
testImds = subset(imds,indicesTest);
testPxds = subset(pxds,indicesTest);
dsTest = combine(testImds,testPxds);
% Start at numTest+1 so that no image appears in both sets
indicesTraining = numTest+1:numFiles;
trainingImds = subset(imds,indicesTraining);
trainingPxds = subset(pxds,indicesTraining);
dsTraining = combine(trainingImds,trainingPxds);

imageSize = [128 128 3];
numClasses = 2;
unetNetwork = unetLayers(imageSize,numClasses,EncoderDepth=3);

opts = trainingOptions("rmsprop", ...
    InitialLearnRate=1e-3, ...
    MaxEpochs=40, ...
    MiniBatchSize=32, ...
    VerboseFrequency=10, ...
    Plots="training-progress", ...
    Shuffle="every-epoch", ...
    ValidationData=dsTest, ...
    ValidationFrequency=10, ...
    OutputNetwork="best-validation-loss");
currentNet = trainNetwork(dsTraining,unetNetwork,opts);
create_ds is a function that returns two datastores: imds contains the RGB images, and pxds contains the corresponding masks, categorical images with two classes (one mask per image in imds).
The RLR function returns [image,label]: the RGB and categorical images after a geometric transformation (their dimensions and data types are preserved).
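A minimal sketch of per-epoch augmentation with RLR, assuming RLR is called as [image,label] = RLR(image,label) (the post only describes its outputs) and that dsTraining is the combined datastore from the base code. The transform function runs the wrapper every time a sample is read, so a randomized RLR produces new rotations on every pass through the data. The helper name augmentWithRLR is hypothetical.

```matlab
% Apply RLR lazily: transform calls augmentWithRLR on every read,
% so each epoch sees freshly augmented samples.
% Assumption: RLR is called as [image,label] = RLR(image,label).
dsTrainingAug = transform(dsTraining, @augmentWithRLR);

% Check one augmented sample: sizes and types should match the originals.
sample = preview(dsTrainingAug);
disp(size(sample{1}))    % RGB image size
disp(class(sample{2}))   % should be categorical

function out = augmentWithRLR(data)
    % data is an N-by-2 cell array {image, label} read from combine(imds,pxds)
    out = cell(size(data));
    for k = 1:size(data, 1)
        [out{k,1}, out{k,2}] = RLR(data{k,1}, data{k,2});
    end
end
```

The augmented datastore can then be passed to trainNetwork (or to minibatchqueue in a custom loop) in place of dsTraining; the validation datastore dsTest is left unaugmented.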
Here is the function that returns the custom loss:
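function loss = combinedLoss(Y, T, alpha, gamma)
% Weighted focal loss + Dice loss.
% Y: network scores (logits), T: one-hot targets of the same size.
% alpha weights the positive term, gamma is the focusing parameter.
epsilon = 1e-6;   % avoids log(0) and division by zero
p = sigmoid(Y);   % if the network already ends in a softmax layer, use p = Y instead
% Weighted focal loss
lossPos = -alpha * (1 - p).^gamma .* T .* log(p + epsilon);
lossNeg = -(1 - alpha) * p.^gamma .* (1 - T) .* log(1 - p + epsilon);
weightedFocalLoss = mean(lossPos + lossNeg, 'all');
% Dice loss, computed on the probabilities p rather than on the raw scores Y
intersection = sum(p .* T, 'all');
union = sum(p, 'all') + sum(T, 'all');
diceCoeff = (2 * intersection + epsilon) / (union + epsilon);
diceLoss = 1 - diceCoeff;
loss = weightedFocalLoss + diceLoss;
end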
I have read the guides that explain how to create a custom training loop, ‘Train network using custom training loop’ and ‘Monitor custom training loop progress’ (among others), several times, but I still don’t understand how to adapt the examples to my semantic segmentation problem with the correct behavior for my input data.
I don’t think sharing the fragments of code I have tried to compose would be helpful, as they are very far from functional.
Any help, whether it answers my question partially or completely, would be greatly appreciated.
Have a nice weekend!

Answers (1)

praguna manvi on 15 July 2024 at 13:09
Hi, you can find many examples of training with a custom loop in this documentation category:
https://www.mathworks.com/help/deeplearning/examples.html?category=custom-training-loops
The GAN and style-transfer examples there are working custom loops that can help with this use case.
You can train the network with a custom learning-rate schedule by computing the learning rate yourself at each iteration and passing it to the “rmspropupdate” function (since you are using RMSProp).
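A minimal sketch of a one-cycle schedule, written as a hypothetical helper oneCycleLR (it is not a toolbox function). It uses cosine annealing: the rate rises from maxLR/divFactor to maxLR over the first pctStart fraction of iterations, then decays to maxLR/(divFactor*finalDivFactor). The default values are borrowed from PyTorch's OneCycleLR and are assumptions, not MATLAB defaults.

```matlab
function lr = oneCycleLR(iteration, numIterations, maxLR, pctStart, divFactor, finalDivFactor)
% ONECYCLELR  Hypothetical one-cycle learning-rate schedule with cosine annealing.
%   lr = oneCycleLR(iteration, numIterations, maxLR) returns the learning rate
%   for iteration 1..numIterations. Default values borrowed from PyTorch's
%   OneCycleLR: pctStart = 0.3, divFactor = 25, finalDivFactor = 1e4.
if nargin < 4, pctStart = 0.3; end
if nargin < 5, divFactor = 25; end
if nargin < 6, finalDivFactor = 1e4; end

initialLR = maxLR / divFactor;        % learning rate at iteration 1
minLR = initialLR / finalDivFactor;   % learning rate at the last iteration
peakIter = max(1, round(pctStart * numIterations));  % iteration where maxLR is reached

if iteration <= peakIter
    % Phase 1: rise from initialLR to maxLR
    t = (iteration - 1) / max(peakIter - 1, 1);
    lr = cosineAnneal(initialLR, maxLR, t);
else
    % Phase 2: decay from maxLR to minLR
    t = (iteration - peakIter) / max(numIterations - peakIter, 1);
    lr = cosineAnneal(maxLR, minLR, t);
end
end

function v = cosineAnneal(startValue, endValue, t)
% Cosine interpolation: t = 0 returns startValue, t = 1 returns endValue.
v = endValue + (startValue - endValue) / 2 * (1 + cos(pi * t));
end
```

In a custom loop, the value for the current iteration goes into the learnRate argument, for example [net,averageSqGrad] = rmspropupdate(net,gradients,averageSqGrad,learnRate), with averageSqGrad initialized to [] before the first iteration.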
A custom loss is handled by writing a model loss function that computes the loss and its gradients (with “dlgradient”) and evaluating it by passing its function handle to “dlfeval”.
To record and visualize training metrics and the loss, consider using “trainingProgressMonitor”. You can also apply a custom augmentation method such as RLR to the images after they are read from the datastores, at each epoch or at specific iterations.
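Putting these pieces together, a minimal sketch of a custom loop for this U-Net, assuming R2023b functions (unetLayers returning a layer graph, dlnetwork, minibatchqueue, trainingProgressMonitor) and reusing dsTraining, trainingImds, imageSize and numClasses from the question's base code plus its combinedLoss. The helper names modelLoss and preprocessMiniBatch, the focal-loss values 0.25 and 2, the [0,1] scaling of uint8 images, and the triangular learning-rate bounds are assumptions for illustration.

```matlab
% Custom training loop sketch (assumes base-code variables and combinedLoss exist).
lgraph = unetLayers(imageSize, numClasses, EncoderDepth=3);

% Drop the softmax and pixel classification layers so the network outputs
% raw scores, which combinedLoss passes through sigmoid.
isHead = arrayfun(@(L) isa(L, 'nnet.cnn.layer.SoftmaxLayer') || ...
    isa(L, 'nnet.cnn.layer.PixelClassificationLayer'), lgraph.Layers);
lgraph = removeLayers(lgraph, {lgraph.Layers(isHead).Name});

% Use an input layer without normalization; images are scaled in
% preprocessMiniBatch instead (assumption: uint8 RGB images).
isInput = arrayfun(@(L) isa(L, 'nnet.cnn.layer.ImageInputLayer'), lgraph.Layers);
inputName = lgraph.Layers(isInput).Name;
lgraph = replaceLayer(lgraph, inputName, ...
    imageInputLayer(imageSize, Normalization="none", Name=inputName));
net = dlnetwork(lgraph);

numEpochs = 40;
miniBatchSize = 32;
maxLearnRate = 1e-3;
focalAlpha = 0.25;   % placeholder focal-loss weight
focalGamma = 2;      % placeholder focusing parameter

mbq = minibatchqueue(dsTraining, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" "SSCB"]);
numIterations = numEpochs * ceil(numel(trainingImds.Files) / miniBatchSize);

monitor = trainingProgressMonitor( ...
    Metrics="TrainingLoss", ...
    Info=["Epoch" "LearnRate"], ...
    XLabel="Iteration");

averageSqGrad = [];
iteration = 0;
epoch = 0;
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    % Reshuffle each epoch (needs a shuffleable datastore; otherwise use reset(mbq)).
    shuffle(mbq);
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;
        [X, T] = next(mbq);

        % Loss and gradients via automatic differentiation.
        [loss, gradients] = dlfeval(@modelLoss, net, X, T, focalAlpha, focalGamma);

        % Triangular one-cycle schedule: 0.1*max -> max (mid-training) -> 0.1*max.
        frac = (iteration - 1) / max(numIterations - 1, 1);
        learnRate = maxLearnRate * (0.1 + 0.9 * (1 - abs(2*frac - 1)));

        [net, averageSqGrad] = rmspropupdate(net, gradients, averageSqGrad, learnRate);

        recordMetrics(monitor, iteration, TrainingLoss=loss);
        updateInfo(monitor, Epoch=epoch, LearnRate=learnRate);
        monitor.Progress = 100 * iteration / numIterations;
    end
end

function [loss, gradients] = modelLoss(net, X, T, alpha, gamma)
    % Forward pass in training mode, custom loss, gradients w.r.t. learnables.
    Y = forward(net, X);
    loss = combinedLoss(Y, T, alpha, gamma);
    gradients = dlgradient(loss, net.Learnables);
end

function [X, T] = preprocessMiniBatch(dataX, dataT)
    % Concatenate samples along the 4th (batch) dimension.
    X = single(cat(4, dataX{:})) / 255;      % H-by-W-by-3-by-N, scaled to [0,1]
    T = onehotencode(cat(4, dataT{:}), 3);   % H-by-W-by-numClasses-by-N
end
```

To apply RLR, pass an augmented datastore (for example one created with transform) to minibatchqueue instead of dsTraining. A validation loss can be tracked as a second metric by adding "ValidationLoss" to Metrics and recording the loss computed with predict on mini-batches of dsTest every few iterations.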
Hope this helps!

Release: R2023b
