Data augmentation in CNN

Srinidhi Gorityala on 21 May 2020
Answered: Srivardhan Gadila on 29 May 2020
Hello, I am working on a dataset of 300 images in 2 classes and want to apply data augmentation. My code is below. I expected augmentation to increase accuracy, but in my case validation accuracy drops from 96% to 87%. Could anyone please help me solve this problem?
Thank you in advance :)
clc;
clear all;
close all;
myTrainingFolder = 'C:\Users\Admin\Desktop\Major Project\cnn_dataset';
%testingFolder = 'C:\Users\Be Happy\Documents\MATLAB\gtsrbtest';
imds = imageDatastore(myTrainingFolder,'IncludeSubfolders', true, 'LabelSource', 'foldernames');
%testingSet = imageDatastore(testingFolder,'IncludeSubfolders', true, 'LabelSource', 'foldernames');
labelCount = countEachLabel(imds);
numClasses = height(labelCount);
numImagesTraining = numel(imds.Files);
%% Create training and validation sets
[imdsTrainingSet, imdsValidationSet] = splitEachLabel(imds, 0.7, 'randomize');
%% Build a simple CNN
imageSize = [227 227 3];
% Specify the convolutional neural network architecture.
layers = [
    imageInputLayer(imageSize)
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
%% Specify training options
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidationSet, ...
    'ValidationFrequency',3, ...
    'Verbose',false, ...
    'Plots','training-progress');
%% Train the network
net1 = trainNetwork(imdsTrainingSet,layers,options);
analyzeNetwork(net1);
%% Report accuracy of baseline classifier on validation set
YPred = classify(net1,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
imdsAccuracy = sum(YPred == YValidation)/numel(YValidation);
%% Plot confusion matrix
figure, plotconfusion(YValidation,YPred)
%% PART 2: Baseline Classifier with Data Augmentation
%% Create augmented image data store
% Specify data augmentation options and values/ranges
imageAugmenter = imageDataAugmenter( ...
'RandRotation',[-20,20], ...
'RandXTranslation',[-5 5], ...
'RandYTranslation',[-5 5]);
% Apply transformations (using randomly picked values) and build augmented
% data store
augImds = augmentedImageDatastore(imageSize,imdsTrainingSet, ...
'DataAugmentation',imageAugmenter);
% (OPTIONAL) Preview augmentation results
batchedData = preview(augImds);
figure, imshow(imtile(batchedData.input))
%% Train the network.
net2 = trainNetwork(augImds,layers,options);
%% Report accuracy of baseline classifier with image data augmentation
YPred = classify(net2,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
augImdsAccuracy = sum(YPred == YValidation)/numel(YValidation);
%% Plot confusion matrix
figure, plotconfusion(YValidation,YPred)
2 Comments
Mohammad Sami on 21 May 2020
Did you try training on the augmented data for more epochs?
Srinidhi Gorityala on 21 May 2020
Mohammad Sami,
No, I did not train it for more epochs.
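A minimal sketch of Mohammad Sami's suggestion, continuing the question's script: it assumes `imdsTrainingSet`, `imdsValidationSet`, `augImds` and `layers` are still in the workspace. With roughly 210 training images and a mini-batch size of 10, each epoch is only about 21 iterations, so 6 epochs is a short run when the augmented datastore produces differently transformed images every epoch. The names `optionsAug`, `net3`, `YPredLonger` and `augImdsAccuracyLonger`, and the value `MaxEpochs = 30`, are illustrative assumptions, not tuned settings.

```matlab
% Assumes the question's script has already run, so imdsTrainingSet,
% imdsValidationSet, augImds and layers exist in the workspace.
miniBatchSize = 10;
numTrain = numel(imdsTrainingSet.Files);
% trainNetwork drops an incomplete final mini-batch, hence floor().
itersPerEpoch = floor(numTrain / miniBatchSize);
fprintf('%d training images -> %d iterations per epoch\n', numTrain, itersPerEpoch);

% Hypothetical options: same as the question's, but trained for longer.
optionsAug = trainingOptions('sgdm', ...
    'MiniBatchSize', miniBatchSize, ...
    'MaxEpochs', 30, ...                        % illustrative value, not tuned
    'InitialLearnRate', 1e-4, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', imdsValidationSet, ...
    'ValidationFrequency', itersPerEpoch, ...   % validate once per epoch
    'Verbose', false, ...
    'Plots', 'training-progress');

% Retrain on the augmented datastore and measure validation accuracy.
net3 = trainNetwork(augImds, layers, optionsAug);
YPredLonger = classify(net3, imdsValidationSet);
augImdsAccuracyLonger = mean(YPredLonger == imdsValidationSet.Labels);
```

Watching the training-progress plot shows whether validation accuracy is still rising when training stops, which is a sign that more epochs may help.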

Accepted Answer

Srivardhan Gadila on 29 May 2020
Here are a few suggestions:
  1. Make sure that all classes have an equal number of observations.
  2. Check how trainNetwork uses an augmented image datastore to transform the training data in each epoch (see the example "Augment Images for Training with Random Geometric Transformations"). Then try training the network on the augmentedImageDatastore for more epochs.
  3. If accuracy still does not improve with augmentedImageDatastore, try changing the network architecture itself. You can refer to "Choose Network Architecture".
  4. Try using dropout layers and a larger global L2 regularization factor in the new architecture. For more information, see dropoutLayer and the 'L2Regularization' option of trainingOptions.
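Points 1 and 4 can be sketched as below, again continuing the question's script (it assumes `imdsTrainingSet`, `imdsValidationSet`, `imageSize`, `imageAugmenter` and `numClasses` exist). Balancing is done here by undersampling the larger class with `splitEachLabel`, which is only one option and discards some of the 300 images. The dropout probability 0.5, `L2Regularization` of 1e-3 (the trainingOptions default is 1e-4) and 30 epochs are illustrative assumptions, and all new variable names are hypothetical.

```matlab
% Assumes the question's script has already run (imdsTrainingSet,
% imdsValidationSet, imageSize, imageAugmenter, numClasses exist).

% Point 1: inspect the class balance of the training split.
tbl = countEachLabel(imdsTrainingSet);
disp(tbl)
% One simple balancing choice (undersampling): keep minCount images per class.
minCount = min(tbl.Count);
imdsBalanced = splitEachLabel(imdsTrainingSet, minCount, 'randomize');
augBalanced = augmentedImageDatastore(imageSize, imdsBalanced, ...
    'DataAugmentation', imageAugmenter);

% Point 4: same CNN as the question, plus a dropout layer before the classifier.
layersReg = [
    imageInputLayer(imageSize)
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    dropoutLayer(0.5)                 % illustrative dropout probability
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

% Stronger global L2 regularization than the default of 1e-4.
optionsReg = trainingOptions('sgdm', ...
    'MiniBatchSize', 10, ...
    'MaxEpochs', 30, ...              % illustrative value, not tuned
    'InitialLearnRate', 1e-4, ...
    'L2Regularization', 1e-3, ...     % illustrative value, not tuned
    'Shuffle', 'every-epoch', ...
    'ValidationData', imdsValidationSet, ...
    'ValidationFrequency', 3, ...
    'Verbose', false, ...
    'Plots', 'training-progress');

netReg = trainNetwork(augBalanced, layersReg, optionsReg);
accReg = mean(classify(netReg, imdsValidationSet) == imdsValidationSet.Labels);
```

`accReg` can be compared with `imdsAccuracy` and `augImdsAccuracy` from the question. With only about 90 validation images, each misclassified image moves accuracy by roughly 1.1 percentage points, so small differences between runs should be read with care.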
