trainNetwork to trainnet conversion

Emre Can Ertekin on 7 Jul 2024 at 16:08
Answered: Paras Gupta on 15 Jul 2024 at 15:11
Hi there,
I was using the trainNetwork (https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html#mw_408bdd15-2d34-4c0d-ad91-bc83942f7493) function for my study. However, in 2024b the trainnet function (https://www.mathworks.com/help/deeplearning/ref/trainnet.html#mw_ffa5eeae-b6e0-444e-a464-91e257cef95b) is slightly faster. I tried to convert my trainNetwork call to trainnet, but I couldn't manage it. How can I convert it? My code is below. Thank you.
%% Train network part
numClasses = numel(categories(trainImgs.Labels));
dropoutProb = 0.2;
layers = [...%my network layers in here.
%% Training Options
options = trainingOptions('adam', ...
'Plots','training-progress',"MiniBatchSize",64, ...
'ValidationData',valImgs,"ExecutionEnvironment","gpu")
%% Training network
trainednet = trainNetwork(trainImgs,layers,options)
% trainednet = trainnet(trainImgs,layers,"crossentropy",options)
1 Comment
Matt J on 7 Jul 2024 at 16:54
Edited: Matt J on 7 Jul 2024 at 17:24
Your post doesn't mention what bad behavior you're seeing with trainnet, nor give us enough detail and input to run the code ourselves.
There's nothing in the way you're calling trainnet that looks "wrong".


Answers (1)

Paras Gupta on 15 Jul 2024 at 15:11
Hi Emre,
I understand that you are experiencing issues when transitioning from the 'trainNetwork' function to the 'trainnet' function in MATLAB.
From the code provided in the question, I assume that you are passing the same layer array to both functions. However, the 'trainnet' function requires a slightly modified network architecture: the layer array must not include the output layer (here, 'classificationLayer'). Instead of using an output layer, you specify the loss function as the third input argument ('lossFcn'), for example "crossentropy".
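As a minimal sketch of that change, assuming the existing layer array ends with 'classificationLayer' as its final element, the output layer can be dropped by indexing. The variable names 'layersOld' and 'layersNew' are illustrative only and do not come from the question. The 'softmaxLayer' is kept, because the "crossentropy" loss operates on the softmax probabilities.

```matlab
% Layer array in the style expected by trainNetwork:
% the final element is a classification output layer.
layersOld = [
    imageInputLayer([28 28 1])
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Assumption: the output layer is the last element of the array.
% Drop it, but keep softmaxLayer as the new final layer.
layersNew = layersOld(1:end-1);

% Optional sanity check: a layer array without output layers
% can be assembled into a dlnetwork.
net = dlnetwork(layersNew);

% The array would then be passed to trainnet together with a loss, e.g.
% trainedNet = trainnet(trainImgs, layersNew, "crossentropy", options);
```

If the output layer is not the last element, or the network is a layer graph rather than a plain array, this indexing shortcut does not apply and the layer has to be removed by name instead.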
The following example code illustrates the difference in the network architectures for both the functions:
%% Load and Preprocess Data
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true,'LabelSource','foldernames');
% Split data into training and validation sets
[trainImgs, valImgs] = splitEachLabel(imds, 0.8, 'randomized');
%% Define Network Architectures
% Network for trainNetwork (includes output layer)
layersTrainNetwork = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
dropoutLayer(0.2)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
% Network for trainnet (does not include the classificationLayer output layer)
% Instead of the output layer, we specify the loss function in the trainnet syntax.
% The softmaxLayer is kept, because the "crossentropy" loss expects probabilities.
layersTrainnet = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
dropoutLayer(0.2)
fullyConnectedLayer(10)
softmaxLayer];
%% Training Options
options = trainingOptions('adam', ...
'Plots','training-progress', ...
'MiniBatchSize',64, ...
'ValidationData',valImgs, ...
'ExecutionEnvironment','gpu');
%% Train Network using trainNetwork
trainednet = trainNetwork(trainImgs, layersTrainNetwork, options);
%% Train Network using trainnet
trainednet_ = trainnet(trainImgs, layersTrainnet, "crossentropy", options);
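One further difference worth noting: 'trainnet' returns a 'dlnetwork' object, so the 'classify' function used with networks from 'trainNetwork' does not apply to it. The following is a rough sketch of how predictions could be obtained instead, assuming the 'trainednet_' and 'valImgs' variables from the example above and a release that provides 'minibatchpredict' and 'scores2label' (R2024a or later). The names 'classNames', 'scores', 'predLabels', and 'valAccuracy' are illustrative.

```matlab
% Continues from the example above: requires trainednet_ and valImgs.

% Class names in the same order that trainnet used for the labels.
classNames = categories(valImgs.Labels);

% Compute class scores for the validation images in mini-batches.
scores = minibatchpredict(trainednet_, valImgs);

% Convert the scores to categorical labels.
predLabels = scores2label(scores, classNames);

% Fraction of validation images classified correctly.
valAccuracy = mean(predLabels == valImgs.Labels);
```

To see accuracy in the training-progress plot as well, the 'Metrics' option of 'trainingOptions' can be set to "accuracy"; without it, the plot for 'trainnet' shows the loss only.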
Please refer to the documentation pages for 'trainNetwork' and 'trainnet' (linked in the question above) for more information on the differences between the two functions.
Hope this helps resolve the issue.

Release

R2024a
