Deep learning using trainnet for resnet18: old trainNetwork --> new trainnet
Hi,
I copied a MATLAB example for semantic image segmentation using MATLAB 2022. The example used 'trainNetwork' and 'deeplabv3plusLayers' with resnet18, and the results were excellent for my images, which have 9 labels (background + 8).
However, the documentation for trainNetwork now recommends switching the training command to 'trainnet', and the documentation for 'deeplabv3plusLayers' recommends switching to 'deeplabv3plus'. When I follow the provided example (e.g., https://www.mathworks.com/help/vision/ug/semantic-segmentation-using-deep-learning.html), my segmentation results are noticeably worse,
especially for my primary regions of interest: ROI1, ROI2, and ROI3.
My program which calls trainnet uses the following options. The function call is also shown here:
options = trainingOptions('sgdm',...
'InitialLearnRate',0.1, ...
'Momentum',0.9,...
'L2Regularization',0.0001,...
'MaxEpochs',80,...
'MiniBatchSize',128,...
'LearnRateSchedule','piecewise',...
'Shuffle','every-epoch',...
'GradientThresholdMethod','l2norm',...
'GradientThreshold',0.05, ...
'Plots','training-progress', ...
'VerboseFrequency',10,...
'ExecutionEnvironment','multi-gpu',...
'ValidationData',dsValid,...
'ValidationFrequency',30,...
'ValidationPatience',15,...
'InputDataFormats', {'SSCB'});
[ROISegNet,info] = trainnet(dsTrain,network,@(Y,T) modelLoss(Y,T,classWeights),options);
where the modelLoss subroutine is as provided by MATLAB:
function loss = modelLoss(Y,T,classWeights)
    % Weighted cross-entropy; pixels whose target is NaN are masked out of the loss.
    weights = dlarray(classWeights,"C");
    mask = ~isnan(T);
    T(isnan(T)) = 0;
    loss = crossentropy(Y,T,weights,Mask=mask,NormalizationFactor="mask-included");
end
I also note that when I set all classWeights to 1, the results were better than with the class weights suggested by the example.
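For context, the class weights in the linked example come from median frequency balancing. The sketch below reproduces that arithmetic on made-up numbers; the pixel counts are hypothetical stand-ins for what countEachLabel would return, and the final rescaling step is an assumption added for comparison, not part of the example.

```matlab
% Hypothetical per-class pixel statistics for 9 classes (background + 8 ROIs).
% In practice these would come from countEachLabel on the pixel label datastore.
pixelCount      = [9.0e6 2.0e5 1.5e5 1.0e5 4.0e5 3.0e5 2.5e5 5.0e5 6.0e5]';
imagePixelCount = 1.0e7 * ones(9,1);   % total pixels in images containing each class

% Median frequency balancing: classes rarer than the median get weights above 1.
imageFreq    = pixelCount ./ imagePixelCount;
classWeights = median(imageFreq) ./ imageFreq;

% Optional (an assumption, not part of the example): rescale so the weights
% average to 1, keeping the loss scale closer to the unweighted case.
classWeightsNorm = classWeights / mean(classWeights);

disp(table(imageFreq, classWeights, classWeightsNorm));
```

With a dominant background class the weights can span roughly two orders of magnitude, which changes the scale of the loss and its gradients. That may interact with the gradient threshold of 0.05 in the options above, although this is only a guess at why uniform weights behaved better.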
I am a novice at this. Can anyone offer suggestions on how to set up trainnet so that the results are closer to those from the old version?
Thank you for your help!
Answers (1)
稼栋
June 15, 2024
It looks like you're facing a challenge with migrating your code from the old syntax using trainNetwork and deeplabv3plusLayers to the new syntax using trainnet and deeplabv3plus. Here are a few suggestions to help you get better results with the new syntax:
- Update the model architecture: Ensure that when you switch to deeplabv3plus, you are using the correct model architecture and configuration. The new syntax might require different settings compared to the old version.
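- Check the loss function: The modelLoss function you provided seems correct, but double-check that it matches the loss used in the older version. With trainNetwork, the loss is defined by the pixel classification layer at the end of the layer graph, so compare the class weights set there with the classWeights you now pass to modelLoss. The loss function can significantly affect the segmentation results.
- Adjust learning rate and other training options: Use the same InitialLearnRate, learning rate schedule settings, MiniBatchSize, and gradient clipping in both scripts so that only the training function differs, then experiment with learning rates, momentum values, and other trainingOptions settings to find what works best for your dataset and model.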
- Data preprocessing: Ensure that your data preprocessing steps, such as resizing, normalization, and augmentation, are consistent between the old and new versions. Inconsistent preprocessing can lead to differences in model performance.
- Debugging and visualization: Use debugging and visualization techniques to understand why the new model might be performing differently. Visualizing the network's predictions and intermediate outputs can provide insights into potential issues.
- Experiment with different initialization: Try different initialization strategies for the network weights. Sometimes, initializing the weights differently can lead to better convergence and results.
- Use transfer learning: If possible, leverage transfer learning by starting with a pre-trained model (e.g., on ImageNet) and fine-tuning it on your dataset. This can often lead to better performance, especially with limited training data.
- Explore other network architectures: If deeplabv3plus does not yield satisfactory results, consider trying other network architectures or variations that are better suited for semantic segmentation tasks.
By experimenting with these suggestions and carefully comparing the old and new versions, you should be able to improve the performance of your semantic segmentation model with the new syntax.
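As one concrete check on the training options, the sketch below prints the learning rate per stage under the 'piecewise' schedule from the question. It assumes LearnRateDropPeriod = 10 and LearnRateDropFactor = 0.1, which I believe are the documented defaults when neither is set; please confirm against the trainingOptions documentation for your release.

```matlab
% Learning rate per epoch under a piecewise schedule.
% Assumption: drop period 10 and drop factor 0.1 (believed defaults, not
% stated in the question); check trainingOptions for your release.
initialLearnRate = 0.1;   % from the options in the question
dropPeriod       = 10;    % assumed default
dropFactor       = 0.1;   % assumed default
maxEpochs        = 80;    % from the options in the question

epochs    = (1:maxEpochs)';
numDrops  = floor((epochs - 1) / dropPeriod);   % drops applied before each epoch
learnRate = initialLearnRate * dropFactor .^ numDrops;

% Print the rate at the first epoch of each stage.
stageStart = epochs(mod(epochs - 1, dropPeriod) == 0);
fprintf('epoch %2d: learn rate %.1e\n', [stageStart, learnRate(stageStart)]');
```

Under these assumptions the rate falls to about 1e-8 for epochs 71 to 80, so the later epochs change the weights very little. Setting LearnRateDropPeriod and LearnRateDropFactor explicitly in both the old and new scripts removes one possible source of difference between them.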