Augmentation of data in image processing

Asked by Srinidhi Gorityala on April 20, 2020
Answered by Srivardhan Gadila on April 23, 2020
Hello, I am working on image augmentation for a CNN, using a dataset of 150 pothole and 150 non-pothole images. My image augmentation code is below, and I plotted two confusion matrices: one before augmentation and one after augmentation.
Each confusion matrix contains only 90 images (TP + TN + FP + FN), but according to the methodology there should be 300 (150 + 150). Could anyone please help me understand this?
clc;
clear all;
close all;
myTrainingFolder = 'C:\Users\Admin\Desktop\Major Project\cnn_dataset';
%testingFolder = 'C:\Users\Be Happy\Documents\MATLAB\gtsrbtest';
imds = imageDatastore(myTrainingFolder,'IncludeSubfolders', true, 'LabelSource', 'foldernames');
%testingSet = imageDatastore(testingFolder,'IncludeSubfolders', true, 'LabelSource', 'foldernames');
labelCount = countEachLabel(imds);
numClasses = height(labelCount);
numImages = numel(imds.Files);  % all images of both labels, before the train/validation split
%% Create training (70% of each label) and validation (30% of each label) sets
[imdsTrainingSet, imdsValidationSet] = splitEachLabel(imds, 0.7, 'randomize');
%% Build a simple CNN
imageSize = [227 227 3];
% Specify the convolutional neural network architecture.
layers = [
    imageInputLayer(imageSize)
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
%% Specify training options
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidationSet, ...
    'ValidationFrequency',3, ...
    'Verbose',false, ...
    'Plots','training-progress');
%% Train the network
net1 = trainNetwork(imdsTrainingSet,layers,options);
%% Report accuracy of baseline classifier on validation set
YPred = classify(net1,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
imdsAccuracy = sum(YPred == YValidation)/numel(YValidation);
%% Plot confusion matrix
figure, plotconfusion(YValidation,YPred)
%% PART 2: Baseline Classifier with Data Augmentation
%% Create augmented image data store
% Specify data augmentation options and values/ranges
imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-20,20], ...
    'RandXTranslation',[-5 5], ...
    'RandYTranslation',[-5 5]);
% Apply transformations (using randomly picked values) and build augmented
% data store
augImds = augmentedImageDatastore(imageSize,imdsTrainingSet, ...
    'DataAugmentation',imageAugmenter);
% (OPTIONAL) Preview augmentation results
batchedData = preview(augImds);
figure, imshow(imtile(batchedData.input))
%% Train the network.
net2 = trainNetwork(augImds,layers,options);
%% Report accuracy of baseline classifier with image data augmentation
YPred = classify(net2,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
augImdsAccuracy = sum(YPred == YValidation)/numel(YValidation);
%% Plot confusion matrix
figure, plotconfusion(YValidation,YPred)
First confusion matrix (before augmentation): 42 + 45 + 0 + 3 = 90 images
Second confusion matrix (after augmentation): 45 + 34 + 11 + 0 = 90 images
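A confusion matrix has exactly one cell hit per classified image, so its entries always add up to the number of labels passed to plotconfusion, never to the size of the whole dataset. The sketch below illustrates this in base MATLAB with hypothetical labels: the class names and the placement of the 3 errors are assumptions chosen to reproduce the first matrix's counts (42, 45, 0, 3). In this sketch rows are true classes and columns are predicted classes.

```matlab
% Hypothetical stand-ins for YValidation and YPred: 45 images per class,
% with 3 'non_pothole' images misclassified as 'pothole' (assumed layout).
classes = categorical({'non_pothole', 'pothole'});
YTrue = [repmat(classes(1), 45, 1); repmat(classes(2), 45, 1)];
YPred = YTrue;
YPred(1:3) = classes(2);

% Count (true class, predicted class) pairs into a 2-by-2 matrix.
% double() on a categorical returns each element's category index.
C = accumarray([double(YTrue), double(YPred)], 1, [2 2]);
disp(C)

% Every classified image lands in exactly one cell.
fprintf('Sum of cells: %d, number of labels: %d\n', sum(C(:)), numel(YTrue));
```

With 90 validation labels, the cells can only sum to 90, whatever the accuracy.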

Accepted Answer

Srivardhan Gadila, April 23, 2020
With
plotconfusion(YValidation,YPred)
you plot the confusion matrix for the validation data only. The line below keeps 30% of each label for validation, so the matrix covers 30% of 300 = 90 images (45 per class):
[imdsTrainingSet, imdsValidationSet] = splitEachLabel(imds, 0.7, 'randomize');
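The 90 can be checked directly with countEachLabel on the validation datastore. The sketch below does this on a throwaway dataset that mimics the question's layout: two label folders under tempdir, each holding 150 tiny placeholder PNGs. The folder names and image contents are hypothetical, and the sketch should need only base MATLAB (imageDatastore, splitEachLabel, countEachLabel).

```matlab
% Throwaway dataset mimicking the question: two label folders with
% 150 blank 8-by-8 PNGs each (folder names are hypothetical).
root = fullfile(tempdir, 'split_demo');
if ~exist(root, 'dir')
    mkdir(root);
end
classNames = {'non_pothole', 'pothole'};
for c = 1:numel(classNames)
    folder = fullfile(root, classNames{c});
    if ~exist(folder, 'dir')
        mkdir(folder);
    end
    for k = 1:150
        imwrite(zeros(8, 8, 'uint8'), fullfile(folder, sprintf('img%03d.png', k)));
    end
end

% Same call pattern as the question's code
imds = imageDatastore(root, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[imdsTrainingSet, imdsValidationSet] = splitEachLabel(imds, 0.7, 'randomize');

% Per-label counts in the validation set, then overall totals
disp(countEachLabel(imdsValidationSet))
fprintf('All: %d, training: %d, validation: %d\n', numel(imds.Files), ...
    numel(imdsTrainingSet.Files), numel(imdsValidationSet.Files));
```

On the question's own data, only the last two statements are needed: they report the per-label validation counts and the 300 / 210 / 90 totals.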
To plot the confusion matrix for the whole dataset, run classify on imds and pass the resulting predictions to plotconfusion.
The following code might help you:
% Classify all 300 images. 70% of them were used for training,
% so this is not a held-out accuracy estimate.
YPred = classify(net1,imds);
YTrue = imds.Labels;
imdsAccuracy = sum(YPred == YTrue)/numel(YTrue);
%% Plot confusion matrix
figure, plotconfusion(YTrue,YPred)
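Both matrices in the question total 90 for the same reason: augmentation is applied only to imdsTrainingSet, and net2 is also evaluated with classify(net2,imdsValidationSet) on the unchanged validation set. augmentedImageDatastore does not add images to the training set either; it applies random transforms to the existing images as they are read, so each epoch still sees the same number of observations. A minimal sketch, assuming Deep Learning Toolbox (already used by the question's code) and a hypothetical in-memory stand-in for the 210 training images:

```matlab
% Hypothetical stand-in for imdsTrainingSet: 210 blank 8-by-8 RGB images
% (105 per class), held in memory instead of on disk.
XTrain = zeros(8, 8, 3, 210, 'uint8');
YTrain = categorical([repmat({'non_pothole'}, 105, 1); repmat({'pothole'}, 105, 1)]);

% Same augmentation ranges as the question's code
augmenter = imageDataAugmenter( ...
    'RandRotation', [-20 20], ...
    'RandXTranslation', [-5 5], ...
    'RandYTranslation', [-5 5]);
augTrain = augmentedImageDatastore([8 8 3], XTrain, YTrain, ...
    'DataAugmentation', augmenter);

% The augmented datastore exposes the same number of observations
fprintf('Training images: %d, augmented observations per epoch: %d\n', ...
    size(XTrain, 4), augTrain.NumObservations);
```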

