How do I use trainNetwork() to retrain a pre-trained model?
How can I replace the decoder and regression layers in my pretrained CAE model with fully connected, softmax, and classification layers, so that I can retrain the model as a classifier?
This is the model I created.
lgraph = layerGraph();
tempLayers = [
imageInputLayer([224 224 3],"Name","imageinput")
convolution2dLayer([3 3],256,"Name","conv_1","Padding","same","Stride",[2 2])
reluLayer("Name","relu_1")
maxPooling2dLayer([1 1],"Name","maxpoolForUnpool_3","HasUnpoolingOutputs",true,"Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],128,"Name","conv_2","Padding","same","Stride",[2 2])
reluLayer("Name","relu_2")
maxPooling2dLayer([1 1],"Name","maxpoolForUnpool_2","HasUnpoolingOutputs",true,"Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],64,"Name","conv_3","Padding","same","Stride",[2 2])
reluLayer("Name","relu_3")
maxPooling2dLayer([1 1],"Name","maxpoolForUnpool_1","HasUnpoolingOutputs",true,"Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
transposedConv2dLayer([3 3],64,"Name","transposed-conv_1","Cropping","same")
reluLayer("Name","relu_4")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
maxUnpooling2dLayer("Name","maxunpool_1")
transposedConv2dLayer([3 3],128,"Name","transposed-conv_2","Cropping","same","Stride",[2 2])
reluLayer("Name","relu_5")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
maxUnpooling2dLayer("Name","maxunpool_2")
transposedConv2dLayer([3 3],256,"Name","transposed-conv_3","Cropping","same","Stride",[2 2])
reluLayer("Name","relu_6")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
maxUnpooling2dLayer("Name","maxunpool_3")
transposedConv2dLayer([3 3],3,"Name","transposed-conv_4","Cropping","same","Stride",[2 2])
reluLayer("Name","relu_7")
regressionLayer("Name","regressionoutput")];
lgraph = addLayers(lgraph,tempLayers);
% clean up helper variable
clear tempLayers;
lgraph = connectLayers(lgraph,"maxpoolForUnpool_3/out","conv_2");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_3/indices","maxunpool_3/indices");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_3/size","maxunpool_3/size");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_2/out","conv_3");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_2/indices","maxunpool_2/indices");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_2/size","maxunpool_2/size");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_1/out","transposed-conv_1");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_1/indices","maxunpool_1/indices");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_1/size","maxunpool_1/size");
lgraph = connectLayers(lgraph,"relu_4","maxunpool_1/in");
lgraph = connectLayers(lgraph,"relu_5","maxunpool_2/in");
lgraph = connectLayers(lgraph,"relu_6","maxunpool_3/in");
Answers (1)
Rahul
13 October 2022
Go to Apps --> Deep Network Designer --> Blank Network.
Once you have built your network by dragging, dropping, and connecting the layers, click Export --> Generate Code. This generates the MATLAB code for your model. If you are still unsure, please post the entire architecture and I will create the network for you.
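The same edit can also be made programmatically on a layer graph. The following is only a minimal sketch on a small stand-in autoencoder, not the network from the question: the layer names `enc_conv`, `enc_relu`, `dec_tconv`, the 32-by-32 input size, and `numClasses = 5` are all assumptions made for illustration. It uses `removeLayers` to drop the decoder and regression output, then `addLayers` and `connectLayers` to attach a classification head.

```matlab
% Small stand-in for an autoencoder layer graph (hypothetical names and sizes).
layers = [
    imageInputLayer([32 32 3],"Name","input","Normalization","none")
    convolution2dLayer(3,8,"Name","enc_conv","Padding","same","Stride",2)
    reluLayer("Name","enc_relu")
    transposedConv2dLayer(3,3,"Name","dec_tconv","Cropping","same","Stride",2)
    regressionLayer("Name","regressionoutput")];
lgraph = layerGraph(layers);

% Remove the decoder and the regression output. Their connections go with them.
lgraph = removeLayers(lgraph,["dec_tconv","regressionoutput"]);

% New classification head. numClasses is an assumed value for illustration.
numClasses = 5;
head = [
    fullyConnectedLayer(numClasses,"Name","fc")
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classoutput")];
lgraph = addLayers(lgraph,head);

% Attach the head to the last encoder layer.
lgraph = connectLayers(lgraph,"enc_relu","fc");
```

If the graph is obtained from a trained network with `layerGraph(net)`, the encoder layers keep their learned weights. In the network from the question, the pooling layers were created with "HasUnpoolingOutputs" set to true, so their `indices` and `size` outputs would be left unconnected once the unpooling layers are removed; those pooling layers would probably need to be replaced with ordinary ones (for example with `replaceLayer`) before training.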
Comments: 5
Rahul
14 October 2022
The code below is just a demo CNN architecture. You can refer to it when building your own CNN architecture.
layers = [
    imageInputLayer([28 28 1])              % image input layer
    convolution2dLayer(5,20)                % 2-D convolutional layer
    reluLayer("Name","relu1")               % ReLU activation layer
    maxPooling2dLayer(2,"Stride",2)         % 2-D max pooling layer
    fullyConnectedLayer(2048,"Name","FC1")  % fully connected layer 1
    reluLayer("Name","relu2")               % ReLU activation layer
    fullyConnectedLayer(1024,"Name","FC2")  % fully connected layer 2
    reluLayer("Name","relu3")               % ReLU activation layer
    fullyConnectedLayer(10)                 % fully connected layer 3 (10 = number of classes)
    softmaxLayer                            % softmax layer: computes class probabilities
    classificationLayer];                   % classification output: marks this as a classification task
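To show how such a layer array is passed to `trainNetwork`, here is a minimal sketch. It assumes the digit sample data that ships with Deep Learning Toolbox (`digitTrain4DArrayData`, 28-by-28 grayscale images with 10 classes) in place of your own data, uses a shortened version of the demo architecture, and picks the training options (one epoch, mini-batch size 128) purely for illustration.

```matlab
% Load the 28x28x1 digit images and categorical labels bundled with the toolbox.
[XTrain,YTrain] = digitTrain4DArrayData;

% Shortened version of the demo classifier (layer sizes are illustrative).
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,"Stride",2)
    fullyConnectedLayer(numel(categories(YTrain)))  % one output per class
    softmaxLayer
    classificationLayer];

% Assumed training settings, kept short so the example runs quickly.
options = trainingOptions("sgdm", ...
    "MaxEpochs",1, ...
    "MiniBatchSize",128, ...
    "Verbose",false);

% Train, then classify a few images to check that the network runs end to end.
net = trainNetwork(XTrain,YTrain,layers,options);
YPred = classify(net,XTrain(:,:,:,1:5));
```

For your own images, an `imageDatastore` with labels can be passed to `trainNetwork` in place of `XTrain` and `YTrain`, and a layer graph can be passed in place of the layer array.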