Need to plot the accuracy vs. epoch graph

Views: 40 (last 30 days)
Md on 13 Nov 2022
Answered: Joss Knight on 14 Nov 2022
allImages = imageDatastore('TrainingData', 'IncludeSubfolders', true,...
'LabelSource', 'foldernames');
%% Split data into training and test sets
[trainingImages, testImages] = splitEachLabel(allImages, 0.8, 'randomize');
%% Load Pre-trained Network (AlexNet)
% AlexNet is a pre-trained network trained on 1000 object categories.
% AlexNet is available as a support package on File Exchange.
alex = alexnet;
%% Review Network Architecture
layers = alex.Layers
%% Modify Pre-trained Network
% AlexNet was trained to recognize 1000 classes; we need to modify it to
% recognize just 4 classes.
layers(23) = fullyConnectedLayer(4); % change this based on # of classes
layers(25) = classificationLayer
%% Perform Transfer Learning
% For transfer learning, we want to change the network weights only slightly. How
% much the network changes during training is controlled by the learning
% rate.
opts = trainingOptions('sgdm', 'InitialLearnRate', 0.001,...
'MaxEpochs', 5, 'MiniBatchSize', 16);
%% Set custom read function
% One useful feature of imageDatastore is that it lets you specify a
% custom read function. Here it simply resizes the input images to
% 227x227 pixels, which is what AlexNet expects. You do this by
% specifying a handle to a function that reads and pre-processes
% each image.
trainingImages.ReadFcn = @readFunctionTrain;
%% Train the Network
% This process usually takes about 5-20 minutes on a desktop GPU.
myNet = trainNetwork(trainingImages, layers, opts);
%% Test Network Performance
% Now let's test the performance of our new "snack recognizer" on the test set.
testImages.ReadFcn = @readFunctionTrain;
predictedLabels = classify(myNet, testImages);
accuracy = mean(predictedLabels == testImages.Labels)
% confusionchart expects the true labels first, then the predicted labels.
confusionchart(testImages.Labels, predictedLabels)
Hello, for the code above, I need to plot the accuracy vs epoch graph. How can I do that? Thank you!
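The script above sets ReadFcn to @readFunctionTrain but does not include that function. What follows is a minimal sketch, assuming (as the script's comment says) that it only reads each file and resizes it to 227x227. The grayscale-to-RGB step is an added assumption for datasets that mix grayscale and color images. Save it as readFunctionTrain.m on the MATLAB path, or put it at the end of the script as a local function (R2016b or later).

```matlab
function I = readFunctionTrain(filename)
% Read one image file and resize it to the 227x227 size AlexNet expects.
I = imread(filename);
% Assumption: replicate single-channel (grayscale) images to 3 channels
% so every image matches AlexNet's 227x227x3 input layer.
if size(I, 3) == 1
    I = repmat(I, [1 1 3]);
end
I = imresize(I, [227 227]);
end
```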


Joss Knight on 14 Nov 2022
Add Plots="training-progress" to your training options.
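The training-progress plot shows accuracy against training iteration. If you want a separate accuracy-vs-epoch figure, one option is to request the second output of trainNetwork. Its TrainingAccuracy field holds the mini-batch accuracy (in percent) for each iteration, and you can average those values per epoch. This is a sketch that assumes the trainingImages and layers variables from the question's script. Note that the per-epoch mean of mini-batch accuracies is only a proxy. It is not the accuracy over the full training set measured at the end of each epoch.

```matlab
% Same options as in the question, plus the live training-progress plot.
opts = trainingOptions('sgdm', 'InitialLearnRate', 0.001, ...
    'MaxEpochs', 5, 'MiniBatchSize', 16, ...
    'Plots', 'training-progress');

% The second output holds per-iteration training metrics.
[myNet, info] = trainNetwork(trainingImages, layers, opts);

% trainNetwork discards the incomplete final mini-batch of each epoch,
% so every epoch has floor(N / MiniBatchSize) iterations.
numTrain = numel(trainingImages.Files);
itersPerEpoch = floor(numTrain / opts.MiniBatchSize);

% Assign each iteration to its epoch and average the accuracy per epoch.
acc = info.TrainingAccuracy(:);
epochIdx = ceil((1:numel(acc))' / itersPerEpoch);
epochAcc = accumarray(epochIdx, acc, [], @mean);

figure
plot(1:numel(epochAcc), epochAcc, '-o')
xlabel('Epoch')
ylabel('Mean mini-batch training accuracy (%)')
grid on
```

If you also set 'ValidationData' in trainingOptions, validation accuracy is added to the progress plot and to info.ValidationAccuracy. That field is NaN on iterations where validation was not run.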
FWIW, you shouldn't use ReadFcn to resize images; it dramatically slows down file access. Use augmentedImageDatastore instead.
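Below is a sketch of the augmentedImageDatastore approach. It assumes the trainingImages, testImages, layers and opts variables from the question's script, with ReadFcn left at its default. The datastore resizes each image as it is read, so readFunctionTrain is no longer needed. The 'gray2rgb' option is an added assumption that only matters if some images are grayscale.

```matlab
% AlexNet's input layer expects [227 227 3]; resize to the first two dims.
inputSize = layers(1).InputSize;

% Resize on the fly instead of through a custom ReadFcn.
% 'gray2rgb' converts any single-channel images to three channels.
augTrain = augmentedImageDatastore(inputSize(1:2), trainingImages, ...
    'ColorPreprocessing', 'gray2rgb');
augTest = augmentedImageDatastore(inputSize(1:2), testImages, ...
    'ColorPreprocessing', 'gray2rgb');

myNet = trainNetwork(augTrain, layers, opts);

% classify returns predictions in the same order as testImages.Files.
predictedLabels = classify(myNet, augTest);
accuracy = mean(predictedLabels == testImages.Labels)
confusionchart(testImages.Labels, predictedLabels)
```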
