How to add augmentation to training and validation data?

Namita Gera, 14 February 2023
Answered: Ashu on 21 February 2023
Hi,
I am trying to resize the training and validation images and add noise to them before training AlexNet, but I keep getting the following error:
The value of 'ValidationData' is invalid. Invalid transform function defined on datastore.
opts = nnet.cnn.TrainingOptionsSGDM(varargin{:});
Caused by:
Too many input arguments.
How can I correct this?
Here is the code I am using:
%load network parameters
params = load("C:\Users\namit\MATLAB Drive\Individualproject\params_2023_02_13__20_34_32.mat");
%load data
digitDatasetPath = fullfile('Tomato - Copy/');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
labelCount = countEachLabel(imds)
img = readimage(imds,1);
size(img)
%Divide the data into training and validation data sets
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomize');
%resize images
augimdsTrain = augmentedImageDatastore([227 227 3],imdsValidation);
augimdsValidation = augmentedImageDatastore([227 227 3],imdsTrain);
%apply noise to images
dsTrain = transform(augimdsTrain,@preprocessForTraining, IncludeInfo=true);
dsValidation = transform(augimdsValidation,@preprocessForTraining, IncludeInfo=true);
%Define the convolutional neural network architecture.
layers = [
    imageInputLayer([227 227 3],"Name","data","Mean",params.data.Mean)
    convolution2dLayer([11 11],96,"Name","conv1","BiasLearnRateFactor",2,"Stride",[4 4],"Bias",params.conv1.Bias,"Weights",params.conv1.Weights)
    reluLayer("Name","relu1")
    crossChannelNormalizationLayer(5,"Name","norm1","K",1)
    maxPooling2dLayer([3 3],"Name","pool1","Stride",[2 2])
    groupedConvolution2dLayer([5 5],128,2,"Name","conv2","BiasLearnRateFactor",2,"Padding",[2 2 2 2],"Bias",params.conv2.Bias,"Weights",params.conv2.Weights)
    reluLayer("Name","relu2")
    crossChannelNormalizationLayer(5,"Name","norm2","K",1)
    maxPooling2dLayer([3 3],"Name","pool2","Stride",[2 2])
    convolution2dLayer([3 3],384,"Name","conv3","BiasLearnRateFactor",2,"Padding",[1 1 1 1],"Bias",params.conv3.Bias,"Weights",params.conv3.Weights)
    reluLayer("Name","relu3")
    groupedConvolution2dLayer([3 3],192,2,"Name","conv4","BiasLearnRateFactor",2,"Padding",[1 1 1 1],"Bias",params.conv4.Bias,"Weights",params.conv4.Weights)
    reluLayer("Name","relu4")
    groupedConvolution2dLayer([3 3],128,2,"Name","conv5","BiasLearnRateFactor",2,"Padding",[1 1 1 1],"Bias",params.conv5.Bias,"Weights",params.conv5.Weights)
    reluLayer("Name","relu5")
    maxPooling2dLayer([3 3],"Name","pool5","Stride",[2 2])
    fullyConnectedLayer(4096,"Name","fc6")
    reluLayer("Name","relu6")
    dropoutLayer(0.5,"Name","drop6")
    fullyConnectedLayer(4096,"Name","fc7")
    reluLayer("Name","relu7")
    dropoutLayer(0.5,"Name","drop7")
    fullyConnectedLayer(6,"Name","fc8")
    softmaxLayer("Name","prob")
    classificationLayer("Name","classoutput")];
options = trainingOptions("sgdm",...
    "ExecutionEnvironment","auto",...
    "InitialLearnRate",0.001,...
    "MaxEpochs",10,...
    "Shuffle","every-epoch",...
    "Plots","training-progress",...
    "ValidationData",dsValidation);
%train network
net = trainNetwork(dsTrain,layers,options);
%function used
function dataOut = preprocessForTraining(data)
    dataOut = data;
    for idx = 1:size(data,1)
        dataOut{idx} = imnoise(data{idx},'salt & pepper');
    end
end

Answers (1)

Ashu, 21 February 2023
Check the following points to fix your code:
  • There is a logic error where you resize the images: imdsValidation is passed to augimdsTrain and imdsTrain to augimdsValidation, so the two sets are swapped. Correct it as follows:
augimdsTrain = augmentedImageDatastore([227 227 3],imdsTrain);
augimdsValidation = augmentedImageDatastore([227 227 3],imdsValidation);
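  • Because you specified 'IncludeInfo' as true, transform calls the transform function with two inputs (the data and the info returned by read) and expects two outputs. Your preprocessForTraining accepts only one input, which is why you get "Too many input arguments". With 'IncludeInfo' set to true, the transform function must have this signature (for a single datastore it reduces to [dataOut,infoOut] = transformFcn(data,info)):
function [dataOut,infoOut] = transformFcn(ds1_data,ds2_data,...,dsN_data,ds1_info,ds2_info,...,dsN_info)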
..
end
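  • If you don't want to change the signature of preprocessForTraining, you can instead remove the 'IncludeInfo' argument from the transform calls. Note, however, that read on an augmentedImageDatastore returns a table with the variables input and response rather than a cell array, so inside the function the images should be accessed as data.input{idx}; data{idx} with a single subscript is not valid indexing for a table.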
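Putting the points above together, a minimal sketch that keeps IncludeInfo=true might look like this. It assumes imdsTrain and imdsValidation come from the question's splitEachLabel call, MATLAB R2021a or later for the Name=Value syntax (already used in the question), and that the images are read as a table with input and response variables, as a labeled augmentedImageDatastore does. The helper name addNoiseWithInfo is hypothetical.

```matlab
% Minimal sketch, assuming imdsTrain and imdsValidation come from the
% question's splitEachLabel call and MATLAB R2021a or later (Name=Value syntax).

% Resize: training images feed the training datastore, validation images the validation one
augimdsTrain      = augmentedImageDatastore([227 227 3], imdsTrain);
augimdsValidation = augmentedImageDatastore([227 227 3], imdsValidation);

% Add salt & pepper noise; with IncludeInfo=true the function receives (data, info)
dsTrain      = transform(augimdsTrain,      @addNoiseWithInfo, IncludeInfo=true);
dsValidation = transform(augimdsValidation, @addNoiseWithInfo, IncludeInfo=true);

% addNoiseWithInfo is a hypothetical helper name.
% Local functions in a script must come at the end of the file.
function [dataOut, infoOut] = addNoiseWithInfo(data, info)
    % data is a table read from the augmentedImageDatastore:
    %   data.input    - cell array of resized images
    %   data.response - labels (left unchanged)
    dataOut = data;
    for idx = 1:height(data)
        dataOut.input{idx} = imnoise(data.input{idx}, 'salt & pepper');
    end
    infoOut = info;   % pass the read info through unchanged
end
```

The resulting dsTrain and dsValidation can then be passed to trainNetwork and to the 'ValidationData' option exactly as in the question.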
To learn more about image data augmentation, you can refer to the following link
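If augmentation beyond noise is wanted, augmentedImageDatastore can also apply random geometric transformations through an imageDataAugmenter passed as 'DataAugmentation'. The sketch below is illustrative only: the specific ranges are assumptions, not values from the question, and it again assumes imdsTrain and imdsValidation from splitEachLabel. As far as I know, imageDataAugmenter offers only geometric options (reflection, rotation, scaling, shearing, translation), so salt & pepper noise still needs a transform function.

```matlab
% Illustrative sketch: the augmentation ranges below are assumed values.
augmenter = imageDataAugmenter( ...
    'RandXReflection', true, ...      % random left-right flips
    'RandRotation', [-10 10], ...     % rotate by a random angle in degrees
    'RandXTranslation', [-5 5], ...   % shift horizontally by up to 5 pixels
    'RandYTranslation', [-5 5]);      % shift vertically by up to 5 pixels

% Training images are resized and randomly augmented on every read
augimdsTrain = augmentedImageDatastore([227 227 3], imdsTrain, ...
    'DataAugmentation', augmenter);

% Validation images are typically only resized, not augmented
augimdsValidation = augmentedImageDatastore([227 227 3], imdsValidation);
```

Because the augmentation is applied on the fly each time a mini-batch is read, the network sees a slightly different version of each training image every epoch without any extra images being stored on disk.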

Categories

Find more on Deep Learning Toolbox in Help Center and File Exchange
