How to improve accuracy of SqueezeNet convolutional neural network?

Jenifer NG, July 17, 2022
Latest answer: Rahul, October 11, 2022
Dear All,
I am new to machine learning.
I just want to use a CNN to classify my images. In this case, I used SqueezeNet.
However, the validation accuracy is only around 67%, as shown in the attached picture.
Could anyone give me some advice, or suggest another model?
I have attached my training data and my model in a .mat file.
Link to download the training data:
https://drive.matlab.com/sharing/09ff1a1b-446b-4c2a-9599-894594dcbf18
I referred to this page:
Transfer Learning with Deep Network Designer - MATLAB & Simulink. (n.d.). Www.mathworks.com. Retrieved July 17, 2022, from https://www.mathworks.com/help/deeplearning/ug/transfer-learning-with-deep-network-designer.html
trainingSetup = load("https://www.mathworks.com/matlabcentral/answers/uploaded_files/1068045/model1.mat");
imdsTrain = imageDatastore("C:\Users\efml\Desktop\Image_analysis","IncludeSubfolders",true,"LabelSource","foldernames");
[imdsTrain, imdsValidation] = splitEachLabel(imdsTrain,0.7);
imageAugmenter = imageDataAugmenter(...
"RandRotation",[-90 90],...
"RandScale",[1 2],...
"RandXReflection",true);
% Resize the images to match the network input layer.
augimdsTrain = augmentedImageDatastore([300 300 1],imdsTrain,"DataAugmentation",imageAugmenter);
augimdsValidation = augmentedImageDatastore([300 300 1],imdsValidation);
opts = trainingOptions("sgdm",...
"ExecutionEnvironment","auto",...
"InitialLearnRate",0.0001,...
"MaxEpochs",8,...
"MiniBatchSize",16,...
"Shuffle","every-epoch",...
"ValidationFrequency",5,...
"Plots","training-progress",...
"ValidationData",augimdsValidation);
lgraph = layerGraph();
tempLayers = [
imageInputLayer([300 300 1],"Name","data")
convolution2dLayer([3 3],64,"Name","conv1","Stride",[2 2])
reluLayer("Name","relu_conv1")
maxPooling2dLayer([3 3],"Name","pool1","Stride",[2 2])
convolution2dLayer([1 1],16,"Name","fire2-squeeze1x1","Bias",trainingSetup.fire2_squeeze1x1.Bias,"Weights",trainingSetup.fire2_squeeze1x1.Weights)
reluLayer("Name","fire2-relu_squeeze1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([1 1],64,"Name","fire2-expand1x1","Bias",trainingSetup.fire2_expand1x1.Bias,"Weights",trainingSetup.fire2_expand1x1.Weights)
reluLayer("Name","fire2-relu_expand1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],64,"Name","fire2-expand3x3","Padding",[1 1 1 1],"Bias",trainingSetup.fire2_expand3x3.Bias,"Weights",trainingSetup.fire2_expand3x3.Weights)
reluLayer("Name","fire2-relu_expand3x3")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
depthConcatenationLayer(2,"Name","fire2-concat")
convolution2dLayer([1 1],16,"Name","fire3-squeeze1x1","Bias",trainingSetup.fire3_squeeze1x1.Bias,"Weights",trainingSetup.fire3_squeeze1x1.Weights)
reluLayer("Name","fire3-relu_squeeze1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],64,"Name","fire3-expand3x3","Padding",[1 1 1 1],"Bias",trainingSetup.fire3_expand3x3.Bias,"Weights",trainingSetup.fire3_expand3x3.Weights)
reluLayer("Name","fire3-relu_expand3x3")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([1 1],64,"Name","fire3-expand1x1","Bias",trainingSetup.fire3_expand1x1.Bias,"Weights",trainingSetup.fire3_expand1x1.Weights)
reluLayer("Name","fire3-relu_expand1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
depthConcatenationLayer(2,"Name","fire3-concat")
maxPooling2dLayer([3 3],"Name","pool3","Padding",[0 1 0 1],"Stride",[2 2])
convolution2dLayer([1 1],32,"Name","fire4-squeeze1x1","Bias",trainingSetup.fire4_squeeze1x1.Bias,"Weights",trainingSetup.fire4_squeeze1x1.Weights)
reluLayer("Name","fire4-relu_squeeze1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],128,"Name","fire4-expand3x3","Padding",[1 1 1 1],"Bias",trainingSetup.fire4_expand3x3.Bias,"Weights",trainingSetup.fire4_expand3x3.Weights)
reluLayer("Name","fire4-relu_expand3x3")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([1 1],128,"Name","fire4-expand1x1","Bias",trainingSetup.fire4_expand1x1.Bias,"Weights",trainingSetup.fire4_expand1x1.Weights)
reluLayer("Name","fire4-relu_expand1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
depthConcatenationLayer(2,"Name","fire4-concat")
convolution2dLayer([1 1],32,"Name","fire5-squeeze1x1","Bias",trainingSetup.fire5_squeeze1x1.Bias,"Weights",trainingSetup.fire5_squeeze1x1.Weights)
reluLayer("Name","fire5-relu_squeeze1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],128,"Name","fire5-expand3x3","Padding",[1 1 1 1],"Bias",trainingSetup.fire5_expand3x3.Bias,"Weights",trainingSetup.fire5_expand3x3.Weights)
reluLayer("Name","fire5-relu_expand3x3")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([1 1],128,"Name","fire5-expand1x1","Bias",trainingSetup.fire5_expand1x1.Bias,"Weights",trainingSetup.fire5_expand1x1.Weights)
reluLayer("Name","fire5-relu_expand1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
depthConcatenationLayer(2,"Name","fire5-concat")
maxPooling2dLayer([3 3],"Name","pool5","Padding",[0 1 0 1],"Stride",[2 2])
convolution2dLayer([1 1],48,"Name","fire6-squeeze1x1","Bias",trainingSetup.fire6_squeeze1x1.Bias,"Weights",trainingSetup.fire6_squeeze1x1.Weights)
reluLayer("Name","fire6-relu_squeeze1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],192,"Name","fire6-expand3x3","Padding",[1 1 1 1],"Bias",trainingSetup.fire6_expand3x3.Bias,"Weights",trainingSetup.fire6_expand3x3.Weights)
reluLayer("Name","fire6-relu_expand3x3")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([1 1],192,"Name","fire6-expand1x1","Bias",trainingSetup.fire6_expand1x1.Bias,"Weights",trainingSetup.fire6_expand1x1.Weights)
reluLayer("Name","fire6-relu_expand1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
depthConcatenationLayer(2,"Name","fire6-concat")
convolution2dLayer([1 1],48,"Name","fire7-squeeze1x1","Bias",trainingSetup.fire7_squeeze1x1.Bias,"Weights",trainingSetup.fire7_squeeze1x1.Weights)
reluLayer("Name","fire7-relu_squeeze1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([1 1],192,"Name","fire7-expand1x1","Bias",trainingSetup.fire7_expand1x1.Bias,"Weights",trainingSetup.fire7_expand1x1.Weights)
reluLayer("Name","fire7-relu_expand1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],192,"Name","fire7-expand3x3","Padding",[1 1 1 1],"Bias",trainingSetup.fire7_expand3x3.Bias,"Weights",trainingSetup.fire7_expand3x3.Weights)
reluLayer("Name","fire7-relu_expand3x3")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
depthConcatenationLayer(2,"Name","fire7-concat")
convolution2dLayer([1 1],64,"Name","fire8-squeeze1x1","Bias",trainingSetup.fire8_squeeze1x1.Bias,"Weights",trainingSetup.fire8_squeeze1x1.Weights)
reluLayer("Name","fire8-relu_squeeze1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],256,"Name","fire8-expand3x3","Padding",[1 1 1 1],"Bias",trainingSetup.fire8_expand3x3.Bias,"Weights",trainingSetup.fire8_expand3x3.Weights)
reluLayer("Name","fire8-relu_expand3x3")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([1 1],256,"Name","fire8-expand1x1","Bias",trainingSetup.fire8_expand1x1.Bias,"Weights",trainingSetup.fire8_expand1x1.Weights)
reluLayer("Name","fire8-relu_expand1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
depthConcatenationLayer(2,"Name","fire8-concat")
convolution2dLayer([1 1],64,"Name","fire9-squeeze1x1","Bias",trainingSetup.fire9_squeeze1x1.Bias,"Weights",trainingSetup.fire9_squeeze1x1.Weights)
reluLayer("Name","fire9-relu_squeeze1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([1 1],256,"Name","fire9-expand1x1","Bias",trainingSetup.fire9_expand1x1.Bias,"Weights",trainingSetup.fire9_expand1x1.Weights)
reluLayer("Name","fire9-relu_expand1x1")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],256,"Name","fire9-expand3x3","Padding",[1 1 1 1],"Bias",trainingSetup.fire9_expand3x3.Bias,"Weights",trainingSetup.fire9_expand3x3.Weights)
reluLayer("Name","fire9-relu_expand3x3")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
depthConcatenationLayer(2,"Name","fire9-concat")
dropoutLayer(0.5,"Name","drop9")
convolution2dLayer([1 1],3,"Name","conv10","BiasLearnRateFactor",10,"WeightLearnRateFactor",10)
reluLayer("Name","relu_conv10")
globalAveragePooling2dLayer("Name","pool10")
softmaxLayer("Name","prob")
classificationLayer("Name","ClassificationLayer_predictions")];
lgraph = addLayers(lgraph,tempLayers);
% clean up helper variable
clear tempLayers;
lgraph = connectLayers(lgraph,"fire2-relu_squeeze1x1","fire2-expand1x1");
lgraph = connectLayers(lgraph,"fire2-relu_squeeze1x1","fire2-expand3x3");
lgraph = connectLayers(lgraph,"fire2-relu_expand1x1","fire2-concat/in1");
lgraph = connectLayers(lgraph,"fire2-relu_expand3x3","fire2-concat/in2");
lgraph = connectLayers(lgraph,"fire3-relu_squeeze1x1","fire3-expand3x3");
lgraph = connectLayers(lgraph,"fire3-relu_squeeze1x1","fire3-expand1x1");
lgraph = connectLayers(lgraph,"fire3-relu_expand1x1","fire3-concat/in1");
lgraph = connectLayers(lgraph,"fire3-relu_expand3x3","fire3-concat/in2");
lgraph = connectLayers(lgraph,"fire4-relu_squeeze1x1","fire4-expand3x3");
lgraph = connectLayers(lgraph,"fire4-relu_squeeze1x1","fire4-expand1x1");
lgraph = connectLayers(lgraph,"fire4-relu_expand3x3","fire4-concat/in2");
lgraph = connectLayers(lgraph,"fire4-relu_expand1x1","fire4-concat/in1");
lgraph = connectLayers(lgraph,"fire5-relu_squeeze1x1","fire5-expand3x3");
lgraph = connectLayers(lgraph,"fire5-relu_squeeze1x1","fire5-expand1x1");
lgraph = connectLayers(lgraph,"fire5-relu_expand3x3","fire5-concat/in2");
lgraph = connectLayers(lgraph,"fire5-relu_expand1x1","fire5-concat/in1");
lgraph = connectLayers(lgraph,"fire6-relu_squeeze1x1","fire6-expand3x3");
lgraph = connectLayers(lgraph,"fire6-relu_squeeze1x1","fire6-expand1x1");
lgraph = connectLayers(lgraph,"fire6-relu_expand3x3","fire6-concat/in2");
lgraph = connectLayers(lgraph,"fire6-relu_expand1x1","fire6-concat/in1");
lgraph = connectLayers(lgraph,"fire7-relu_squeeze1x1","fire7-expand1x1");
lgraph = connectLayers(lgraph,"fire7-relu_squeeze1x1","fire7-expand3x3");
lgraph = connectLayers(lgraph,"fire7-relu_expand3x3","fire7-concat/in2");
lgraph = connectLayers(lgraph,"fire7-relu_expand1x1","fire7-concat/in1");
lgraph = connectLayers(lgraph,"fire8-relu_squeeze1x1","fire8-expand3x3");
lgraph = connectLayers(lgraph,"fire8-relu_squeeze1x1","fire8-expand1x1");
lgraph = connectLayers(lgraph,"fire8-relu_expand3x3","fire8-concat/in2");
lgraph = connectLayers(lgraph,"fire8-relu_expand1x1","fire8-concat/in1");
lgraph = connectLayers(lgraph,"fire9-relu_squeeze1x1","fire9-expand1x1");
lgraph = connectLayers(lgraph,"fire9-relu_squeeze1x1","fire9-expand3x3");
lgraph = connectLayers(lgraph,"fire9-relu_expand1x1","fire9-concat/in1");
lgraph = connectLayers(lgraph,"fire9-relu_expand3x3","fire9-concat/in2");
[net, traininfo] = trainNetwork(augimdsTrain,lgraph,opts);
Thanks all,

Accepted Answer

Amanjit Dulai, September 11, 2022
Getting good performance with deep learning can take some experimenting. I tried some simple things with the data you linked to:
  • Transfer learning with a larger network (resnet18).
  • Converting the input images to colour so that we could use the original weights for the first convolution layer.
With these changes, I was able to get validation accuracy in the high 90s. Below is my code:
imdsTrain = imageDatastore("Image_analysis","IncludeSubfolders",true,"LabelSource","foldernames");
[imdsTrain, imdsValidation] = splitEachLabel(imdsTrain,0.7);
imageAugmenter = imageDataAugmenter(...
"RandRotation",[-90 90],...
"RandScale",[1 2],...
"RandXReflection",true);
% Resize the images to match the network input layer.
augimdsTrain = augmentedImageDatastore( [300 300], imdsTrain, ...
"DataAugmentation", imageAugmenter, ...
"ColorPreprocessing", "gray2rgb" );
augimdsValidation = augmentedImageDatastore( [300 300], imdsValidation, ...
"ColorPreprocessing", "gray2rgb" );
opts = trainingOptions("sgdm",...
"ExecutionEnvironment","auto",...
"InitialLearnRate",0.01,...
"MaxEpochs",8,...
"MiniBatchSize",16,...
"Shuffle","every-epoch",...
"ValidationFrequency",5,...
"Plots","training-progress",...
"ValidationData",augimdsValidation);
net = resnet18();
lg = layerGraph(net);
lg = replaceLayer(lg, "data", ...
imageInputLayer([300 300 3], ...
Normalization="zscore"));
lg = removeLayers(lg, ["fc1000" "prob" "ClassificationLayer_predictions"]);
lg = addLayers(lg, [
fullyConnectedLayer(3, ...
Name="fc")
softmaxLayer()
classificationLayer()
]);
lg = connectLayers(lg, "pool5", "fc");
[net, traininfo] = trainNetwork(augimdsTrain,lg,opts);
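The same two ideas can also be tried on SqueezeNet itself: start from the pretrained `squeezenet()` network, whose `conv1` keeps its ImageNet weights (the hand-built graph in the question creates `conv1` without weights), and feed it 227-by-227 RGB images instead of changing the input layer. The script below is a minimal sketch of that, not code from this thread. It assumes the SqueezeNet support package is installed and that `Image_analysis` holds one subfolder per class; the learning rate of 1e-3 is only an assumed starting point, and the layer names `new_conv10` and `new_classoutput` are arbitrary.

```matlab
% Minimal sketch: fine-tune the pretrained SqueezeNet (not the accepted answer's code).
% Assumes the "Deep Learning Toolbox Model for SqueezeNet Network" support package
% is installed and that "Image_analysis" has one subfolder of images per class.
imds = imageDatastore("Image_analysis", "IncludeSubfolders", true, ...
    "LabelSource", "foldernames");
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.7, "randomized");
numClasses = numel(categories(imdsTrain.Labels));

net = squeezenet();                    % pretrained on ImageNet
inputSize = net.Layers(1).InputSize;   % 227-by-227-by-3 for SqueezeNet
lgraph = layerGraph(net);

% Keep every pretrained layer and replace only the class-specific ones.
% conv10 is the last learnable layer; the larger learn-rate factors let it
% adapt faster than the pretrained layers.
newConv = convolution2dLayer(1, numClasses, "Name", "new_conv10", ...
    "WeightLearnRateFactor", 10, "BiasLearnRateFactor", 10);
lgraph = replaceLayer(lgraph, "conv10", newConv);
lgraph = replaceLayer(lgraph, "ClassificationLayer_predictions", ...
    classificationLayer("Name", "new_classoutput"));

% Same augmentation as in the question. "gray2rgb" replicates grayscale images
% to 3 channels so the pretrained conv1 weights apply, and the datastore
% resizes images to the network input size.
imageAugmenter = imageDataAugmenter( ...
    "RandRotation", [-90 90], ...
    "RandScale", [1 2], ...
    "RandXReflection", true);
augTrain = augmentedImageDatastore(inputSize(1:2), imdsTrain, ...
    "DataAugmentation", imageAugmenter, "ColorPreprocessing", "gray2rgb");
augVal = augmentedImageDatastore(inputSize(1:2), imdsValidation, ...
    "ColorPreprocessing", "gray2rgb");

% InitialLearnRate is an assumed starting value; tune it for your data.
opts = trainingOptions("sgdm", ...
    "InitialLearnRate", 1e-3, ...
    "MaxEpochs", 8, ...
    "MiniBatchSize", 16, ...
    "Shuffle", "every-epoch", ...
    "ValidationData", augVal, ...
    "ValidationFrequency", 5, ...
    "Plots", "training-progress");
trainedNet = trainNetwork(augTrain, lgraph, opts);

% Final validation accuracy and a per-class confusion matrix.
YPred = classify(trainedNet, augVal);
valAccuracy = mean(YPred == imdsValidation.Labels);
fprintf("Validation accuracy: %.1f%%\n", 100 * valAccuracy);
confusionchart(imdsValidation.Labels, YPred);
```

The confusion chart shows which classes are being mixed up, which says more than the single accuracy number when deciding whether to change the augmentation, the learning rate, or the network.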

More Answers (1)

Rahul, October 11, 2022
Obtaining good performance on a dataset with deep learning depends on various parameters, such as the type of CNN architecture you use, the number of epochs, the optimization technique, and the learning rate. These hyperparameters play an important role in getting the best performance from the CNN model on your data.
Checking each hyperparameter by trial and error is cumbersome: changing a hyperparameter, running the code again, and saving the results is time-consuming. MathWorks provides a solution for this with the Experiment Manager app. The Experiment Manager app enables you to create deep learning experiments to train networks under multiple initial conditions and compare the results. You can visit the following links on YouTube for more insight.
Experiment Manager provides visualization tools such as training plots and confusion matrices, filters to refine your experiment results, and annotations to record your observations. To improve reproducibility, every time that you run an experiment, Experiment Manager stores a copy of the experiment definition. You can access past experiment definitions to keep track of the hyperparameter combinations that produce each of your results.
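For comparison, the trial-and-error loop that Experiment Manager automates can be written by hand for a single hyperparameter. The function below is a minimal sketch of such a learning-rate sweep; `sweepLearnRate` is a hypothetical name (not a MathWorks function), the fixed training options are copied from the question, and it assumes the datastores and layer graph are built as in the accepted answer. Save it as `sweepLearnRate.m` or place it at the end of a script.

```matlab
function results = sweepLearnRate(augTrain, augVal, valLabels, lgraph, learnRates)
% sweepLearnRate  Hypothetical helper (not a MathWorks function): train one
% network per initial learning rate and record its validation accuracy.
%   augTrain, augVal - training and validation datastores
%   valLabels        - true validation labels, in the same order as augVal
%   lgraph           - layer graph to train, e.g. "lg" from the accepted answer
%   learnRates       - vector of initial learning rates to try
numRuns = numel(learnRates);
accuracy = zeros(numRuns, 1);
for k = 1:numRuns
    rng(0);   % same seed for every run to reduce run-to-run variation
    opts = trainingOptions("sgdm", ...
        "InitialLearnRate", learnRates(k), ...
        "MaxEpochs", 8, ...
        "MiniBatchSize", 16, ...
        "Shuffle", "every-epoch", ...
        "ValidationData", augVal, ...
        "ValidationFrequency", 5, ...
        "Verbose", false);
    trainedNet = trainNetwork(augTrain, lgraph, opts);
    YPred = classify(trainedNet, augVal);
    accuracy(k) = mean(YPred == valLabels);
    fprintf("InitialLearnRate = %g: validation accuracy %.1f%%\n", ...
        learnRates(k), 100 * accuracy(k));
end
% One row per run, so the best setting is easy to pick out.
results = table(learnRates(:), accuracy, ...
    "VariableNames", {'InitialLearnRate', 'ValidationAccuracy'});
end
```

With the accepted answer's variables in the workspace, a call such as `results = sweepLearnRate(augimdsTrain, augimdsValidation, imdsValidation.Labels, lg, [1e-3 1e-2 1e-1])` trains three networks and returns one table row per learning rate. Experiment Manager handles this bookkeeping for you and adds the training plots, confusion matrices, and stored experiment definitions described above.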

Category: Image Data Workflows

Release: R2022a
