How do I create and plot a confusion matrix for my trained convolutional neural network?

I can't seem to create a confusion matrix from the validation results of my trained convolutional neural network. Below is the code I am using; thanks in advance for any help!
-----------------------------------------------------------------------------------
clear
rng('shuffle')
outputFolder = fullfile('D:\Large_grains\Training_set');
trainDigitData = imageDatastore(outputFolder,'IncludeSubfolders',true,'LabelSource','foldernames');
outputFolder = fullfile('D:\Large_grains\Validation_set');
testDigitData = imageDatastore(outputFolder,'IncludeSubfolders',true,'LabelSource','foldernames');
inputSize = [224 224 3];
augimdsTrain = augmentedImageDatastore(inputSize,trainDigitData,'ColorPreprocessing','gray2rgb');
augimdsValidation = augmentedImageDatastore(inputSize,testDigitData,'ColorPreprocessing','gray2rgb');
numClasses = 9;
problem2; % load ResNet-18
miniBatchSize = 32;
validationFrequency = floor(numel(trainDigitData.Labels)/miniBatchSize);
options = trainingOptions('sgdm',...
'LearnRateSchedule','piecewise',...
'LearnRateDropFactor',0.1,...
'LearnRateDropPeriod',2,...
'MaxEpochs',10,...
'InitialLearnRate',0.001,...
'MiniBatchSize',miniBatchSize,...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',validationFrequency);
convnet = trainNetwork(augimdsTrain,lgraph,options);
[YPred] = classify(convnet,augimdsValidation);
plotconfusion(augimdsValidation.Labels,YPred)
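One likely cause, although the error message was never posted: `augmentedImageDatastore` does not appear to expose a `Labels` property, so `augimdsValidation.Labels` on the last line may fail. The true labels live on the underlying `imageDatastore` (`testDigitData`), and `classify` returns predictions in the same file order. The sketch below is a minimal, self-contained illustration with made-up labels standing in for `testDigitData.Labels` and `YPred`. It computes validation accuracy and the confusion matrix (`confusionmat` is from the Statistics and Machine Learning Toolbox) and then plots it.

```matlab
% Made-up labels standing in for the real ones. In the script above these
% would be:
%   YValidation = testDigitData.Labels;   % not augimdsValidation.Labels
%   YPred = classify(convnet,augimdsValidation);
YValidation = categorical(["A";"A";"B";"B";"C";"C"]);
YPred       = categorical(["A";"B";"B";"B";"C";"A"]);

% Fraction of validation images classified correctly.
accuracy = mean(YPred == YValidation);

% Confusion matrix: C(i,j) counts images of true class order(i)
% that were predicted as class order(j).
[C, order] = confusionmat(YValidation, YPred);

% Plot it; plotconfusion(YValidation,YPred) also accepts categorical labels.
confusionchart(YValidation, YPred);
```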
  Comments: 2
Shivam Singh on 29 Nov 2021
Hello Steven,
Can you share the error you are getting with this code? Also, can you share more information about the model ("lgraph") and the dataset used?
Steven Mozarowski on 29 Nov 2021
Thanks for your response, Shivam! I actually managed to have the script produce a confusion matrix earlier today and was meaning to take this post down when I saw your comment!
To answer your questions:
lgraph is a chart that shows information (like validation accuracy, epoch, time elapsed etc.) as training progresses.
The dataset is a pile of starch grain micrographs I had captured using an imaging flow cytometer. The images are organized in folders on a hard drive with 300 training and 200 validation images per species.
Thanks again for reaching out, I really appreciate it!
-Steven
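With 200 validation images for each of the 9 species, every row of the confusion matrix should sum to 200, which makes a quick sanity check, and dividing the diagonal by those row totals gives per-species recall. A hedged sketch on a small balanced toy set (2 made-up images per class instead of 200 per species):

```matlab
% Made-up balanced labels: 2 images per class.
YValidation = categorical(["A";"A";"B";"B";"C";"C"]);
YPred       = categorical(["A";"B";"B";"B";"C";"A"]);

% Rows = true class, columns = predicted class.
[C, order] = confusionmat(YValidation, YPred);

% In a balanced set every row total equals the per-class image count.
rowTotals = sum(C, 2);

% Per-class recall: correctly classified images / all images of that class.
recall = diag(C) ./ rowTotals;

% Chart with each row shown as percentages of that class.
confusionchart(C, order, 'Normalization', 'row-normalized');
```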


Answers (1)

yanqi liu on 2 Dec 2021
Yes, sir. If you want the underlying numbers rather than just the plot, you may be able to use:
[c,cm,ind,per] = confusion(augimdsValidation.Labels,YPred)
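As far as I know, `confusion` expects `targets` and `outputs` as S-by-Q matrices (S classes, Q samples) with one-hot columns rather than categorical vectors, so the labels likely need converting first (and, as with `plotconfusion`, the true labels would come from `testDigitData.Labels`). A minimal sketch with made-up labels; the conversion uses plain indexing rather than any toolbox helper:

```matlab
% Made-up labels standing in for testDigitData.Labels and YPred.
YValidation = categorical(["A";"A";"B";"B";"C";"C"]);
YPred       = categorical(["A";"B";"B";"B";"C";"A"]);

% Shared class list so row i means the same class in both matrices.
classes = categories(YValidation);
S = numel(classes);          % number of classes
Q = numel(YValidation);      % number of samples

% One-hot S-by-Q matrices: column q has a single 1 in its class row.
targets = zeros(S, Q);
outputs = zeros(S, Q);
targets(sub2ind([S Q], double(YValidation)', 1:Q)) = 1;

% Map predictions onto the same class order before indexing.
[~, predIdx] = ismember(cellstr(YPred), classes);
outputs(sub2ind([S Q], predIdx', 1:Q)) = 1;

% c: fraction misclassified; cm(i,j): samples of class i predicted as j.
[c, cm, ind, per] = confusion(targets, outputs);
```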
