Deep learning using trainnet for resnet18: old trainNetwork --> new trainnet

Matt K. on June 14, 2024
Commented: Matt K. on June 19, 2024
Hi,
I copied a MATLAB example for semantic image segmentation in MATLAB 2022. The example used 'trainNetwork' and 'deeplabv3plusLayers' with resnet18, and the results were excellent for my images, which have 9 labels (background + 8).
However, the documentation for 'trainNetwork' now recommends switching to 'trainnet', and the documentation for 'deeplabv3plusLayers' recommends switching to 'deeplabv3plus'. When I follow the provided example (e.g., https://www.mathworks.com/help/vision/ug/semantic-segmentation-using-deep-learning.html), my segmentation results are noticeably worse:
...especially for my primary ROIs: ROI1, ROI2, and ROI3.
My program, which calls trainnet, uses the following options; the function call is shown as well:
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'Momentum',0.9, ...
    'L2Regularization',0.0001, ...
    'MaxEpochs',80, ...
    'MiniBatchSize',128, ...
    'LearnRateSchedule','piecewise', ...
    'Shuffle','every-epoch', ...
    'GradientThresholdMethod','l2norm', ...
    'GradientThreshold',0.05, ...
    'Plots','training-progress', ...
    'VerboseFrequency',10, ...
    'ExecutionEnvironment','multi-gpu', ...
    'ValidationData',dsValid, ...
    'ValidationFrequency',30, ...
    'ValidationPatience',15, ...
    'InputDataFormats',{'SSCB'});
[ROISegNet,info] = trainnet(dsTrain,network,@(Y,T) modelLoss(Y,T,classWeights),options);
where modelLoss is the loss function provided in the MATLAB example:
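function loss = modelLoss(Y,T,classWeights)
    % Per-class weights as a dlarray with a channel ("C") dimension
    weights = dlarray(classWeights,"C");
    % Unlabeled pixels appear as NaN in the targets; mask them out and zero them
    mask = ~isnan(T);
    T(isnan(T)) = 0;
    % Weighted cross-entropy, normalized over the elements included by the mask
    loss = crossentropy(Y,T,weights,Mask=mask,NormalizationFactor="mask-included");
end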
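modelLoss depends on unlabeled pixels arriving as NaN in the one-hot targets: it sets them to 0 and passes a mask so that, with NormalizationFactor="mask-included", they should not count toward the loss. The following is a minimal synthetic check of those same steps, assuming Deep Learning Toolbox (R2021a or later for the Name=Value syntax). The 2-by-2 image, the three classes, and the weights are hypothetical, and the body of modelLoss is repeated inline so the block runs on its own.

```matlab
% Synthetic check of the masking steps used in modelLoss (hypothetical data).
% Requires Deep Learning Toolbox; Name=Value syntax needs R2021a or later.
rng(0)                                      % reproducible random logits
numClasses = 3;
classWeights = [0.5; 1; 2];                 % hypothetical class weights
weights = dlarray(classWeights,"C");

% One-hot targets for one 2-by-2 image; pixel (2,2) is unlabeled (NaN in every channel)
Traw = zeros(2,2,numClasses,1);
Traw(1,1,1) = 1;
Traw(1,2,2) = 1;
Traw(2,1,3) = 1;
Traw(2,2,:) = NaN;
T = dlarray(Traw,"SSCB");

% Same two steps as modelLoss: build the mask, then replace NaN targets with 0
mask = ~isnan(T);
T(isnan(T)) = 0;

% Two sets of predictions that differ only at the unlabeled pixel
logits1 = randn(2,2,numClasses,1);
logits2 = logits1;
logits2(2,2,:) = randn(1,1,numClasses);
Y1 = softmax(dlarray(logits1,"SSCB"));
Y2 = softmax(dlarray(logits2,"SSCB"));

loss1 = crossentropy(Y1,T,weights,Mask=mask,NormalizationFactor="mask-included");
loss2 = crossentropy(Y2,T,weights,Mask=mask,NormalizationFactor="mask-included");
```

Because the two prediction sets differ only at the unlabeled pixel, loss1 and loss2 should match. A NaN loss or a mismatch on real data would suggest that unlabeled pixels are not reaching the loss function as NaN.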
I also noticed that setting all classWeights to 1 gave better results than the class weights suggested by the example.
As a novice, can anyone suggest how to set up trainnet so that its results are closer to those of the old version?
Thank you for your help!
3 Comments
Matt K. on June 18, 2024
Edited: Matt J on June 18, 2024
I checked, and I reset the class weights to match the example:
% See the distribution of class labels
tbl = countEachLabel(pxds);
% Use class weighting to balance the classes.
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq;
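The weighting above is median frequency balancing: each class's weight is the median class frequency divided by that class's own frequency, so rare classes get weights above 1 and common ones below 1. A worked example with made-up pixel counts for three classes (the class names and numbers are hypothetical; in practice they come from countEachLabel):

```matlab
% Hypothetical countEachLabel-style table for three classes
Name = ["background"; "ROI1"; "ROI2"];
PixelCount = [900000; 60000; 6000];            % pixels labeled with each class
ImagePixelCount = [1000000; 600000; 300000];   % total pixels in images containing that class
tbl = table(Name, PixelCount, ImagePixelCount);

% Median frequency balancing, same formula as above
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;   % [0.9; 0.1; 0.02]
classWeights = median(imageFreq) ./ imageFreq;       % approx. [0.111; 1; 5]

% Ratio between the largest and smallest class weight
weightRatio = max(classWeights) / min(classWeights); % about 45
```

With these numbers, the rarest class is weighted about 45 times more heavily than the background, so missing its pixels costs far more than over-predicting it. That trade-off can plausibly show up as rare labels spreading into neighboring regions, which may be one reason uniform weights looked better here.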
This appeared to help greatly in the end:
...but the final segmentation still wasn't very good (note the 'bleeding' of labels):
...whereas this shows how well the previous method worked:
Regarding the answer from another user below: I wasn't able to determine which loss function the previous method used:
lgraph = deeplabv3plusLayers(imageSize,numClasses,'resnet18');
[net,info] = trainNetwork(dsTrain,lgraph,options);
Is the modelLoss function in my original post identical or similar to the loss used by the previous method? If not, what do you recommend?
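One way to see which loss the old two-line workflow used is to inspect the output layer that deeplabv3plusLayers creates. A sketch, assuming Computer Vision Toolbox, Deep Learning Toolbox, and the ResNet-18 support package are installed; the input size is a placeholder, and the 9 classes match the original post:

```matlab
% Build the old-style layer graph (placeholder input size; 9 classes as in the post)
imageSize = [256 256 3];
numClasses = 9;
lgraph = deeplabv3plusLayers(imageSize, numClasses, "resnet18");

% Locate the pixel classification output layer by its class rather than by position
isPixelOut = arrayfun(@(L) isa(L, 'nnet.cnn.layer.PixelClassificationLayer'), lgraph.Layers);
pxOut = lgraph.Layers(isPixelOut);

% Without an explicit replacement, ClassWeights keeps its default value
disp(pxOut.Name)
disp(pxOut.ClassWeights)
```

If ClassWeights is 'none' (the documented default for pixelClassificationLayer), the old training used unweighted pixel cross-entropy, and the closest trainnet counterpart is roughly modelLoss with classWeights = ones(numClasses,1), which is consistent with uniform weights working better. If the old script instead replaced this layer (for example, with replaceLayer and a pixelClassificationLayer that sets ClassWeights), the weighted modelLoss is the closer match. Normalization details may still differ between the two APIs.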
Thank you for your help!
-Matt
Matt K. on June 19, 2024
Update:
I followed the example on the deeplabv3plus documentation page, replaced the triangle data with my own data, and got very good visual results on the first pass:
I'm not sure why the other method, which used the separate modelLoss function, didn't work nearly as well, but I now have something that works. Thank you @Joss Knight and @稼栋 for your helpful suggestions.


Answers (1)

稼栋 on June 15, 2024
It looks like you're having trouble migrating your code from the old syntax (trainNetwork and deeplabv3plusLayers) to the new syntax (trainnet and deeplabv3plus). Here are a few suggestions for getting better results with the new syntax:
  1. Check the model architecture: When you switch to deeplabv3plus, make sure the base network, input size, and classes match your old configuration. The new function can require different inputs than the old one; for example, deeplabv3plus takes the class names, whereas deeplabv3plusLayers took the number of classes.
  2. Check the loss function: The modelLoss function you provided looks correct, but confirm that it matches the loss used in the older version. In the trainNetwork workflow, the loss was defined by the network's final pixelClassificationLayer, including any ClassWeights set on it. The loss function can significantly affect the segmentation results.
  3. Adjust learning rate and other training options: Experiment with different learning rates, momentum values, and other training options in trainingOptions to find settings that work best for your dataset and model.
  4. Data preprocessing: Ensure that your data preprocessing steps, such as resizing, normalization, and augmentation, are consistent between the old and new versions. Inconsistent preprocessing can lead to differences in model performance.
  5. Debugging and visualization: Use debugging and visualization techniques to understand why the new model might be performing differently. Visualizing the network's predictions and intermediate outputs can provide insights into potential issues.
  6. Experiment with different initialization: Try different initialization strategies for the network weights. Sometimes, initializing the weights differently can lead to better convergence and results.
  7. Use transfer learning: If possible, leverage transfer learning by starting with a pre-trained model (e.g., on ImageNet) and fine-tuning it on your dataset. This can often lead to better performance, especially with limited training data.
  8. Explore other network architectures: If deeplabv3plus does not yield satisfactory results, consider trying other network architectures or variations that are better suited for semantic segmentation tasks.
By experimenting with these suggestions and carefully comparing the old and new versions, you should be able to improve the performance of your semantic segmentation model with the new syntax.
