predict function not working in custom training loop

Isabella on 25 Oct 2022
Commented: Isabella on 25 Oct 2022
I am building a custom training loop for a simple LSTM classification network because I need a custom loss function (specifically, 0-1 loss). I have followed a tutorial, but when I call the predict function within my custom loss function, I get the error: "Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'."
I can successfully train the network with trainNetwork, so I am wondering what trainNetwork is doing that I am missing. If I call predict after training with trainNetwork, it works, but not in the training loop.
My input is a 50x1 sequence that is classified as either 1 or 2 (depending on whether its average is positive or negative).
My network is defined as follows:
numFeatures = 1;      % one input value per time step (50 time points per sequence)
numHiddenUnits = 100;
numClasses = 2;       % left/right decision at the end
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,"OutputMode","last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
net = layers;
and my custom loss function is:
function [gradients,state,loss] = customGradients2Lay(net,dlX,Ylabel)
    [Y,state] = predict(net,dlX);
    loss = loss01(Y,Ylabel);
    gradients = dlgradient(loss,net.Learnables);
end

function loss = loss01(Y,T)
    if isequal(Y,T)
        loss = 0;
    else
        loss = 1;
    end
end
My training loop just calls a random dataset to test on and sends it to predict. The error again is: Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'.
I am also not sure why it's calling nnet.cnn. I even built a custom fully connected layer to try to get around this, and it was still calling the nnet.cnn class.
What am I missing?

Accepted Answer

James Gross on 25 Oct 2022
Hello,
To train your network in a custom training loop, you must specify your network as a dlnetwork.
net = dlnetwork(layers);
You should then be able to train and call predict on your network as desired. For examples of how to train using a custom training loop with a dlnetwork, you can refer to the custom training loop examples in the Deep Learning Toolbox documentation.
I hope this information helps!
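
A minimal sketch of this fix, assuming the layer definitions from the question. Several things here are assumptions rather than part of the original answer: classificationLayer is dropped because dlnetwork does not accept output layers; the random batch, the label convention, and the names modelLoss, numObs, and numSteps are hypothetical; and crossentropy stands in for the 0-1 loss, since loss01 as written returns a plain constant that dlgradient cannot trace back to the learnable parameters.

```matlab
% Sketch only: data, sizes, and helper names are illustrative assumptions.
numFeatures = 1;
numHiddenUnits = 100;
numClasses = 2;

% Layers from the question, without classificationLayer
% (dlnetwork does not accept output layers).
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,"OutputMode","last")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

disp(class(layers))        % a layer array, which has no predict method
net = dlnetwork(layers);   % wrap the layers as suggested in the answer
disp(class(net))           % now a dlnetwork

% Hypothetical random mini-batch: 8 sequences of 50 time steps each.
numObs = 8;
numSteps = 50;
X = single(randn(numFeatures,numObs,numSteps));
labels = (mean(X,3) > 0) + 1;                 % assumed rule: class 2 if mean > 0
T = single([labels == 1; labels == 2]);       % one-hot targets, numClasses-by-numObs

dlX = dlarray(X,"CBT");    % channel, batch, time
dlT = dlarray(T,"CB");     % channel, batch

% predict now works because net is a dlnetwork.
Y = predict(net,dlX);
disp(size(Y))              % numClasses-by-numObs

% One training step: evaluate loss and gradients, then apply an Adam update.
avgGrad = [];
avgSqGrad = [];
iteration = 1;
[loss,gradients] = dlfeval(@modelLoss,net,dlX,dlT);
[net,avgGrad,avgSqGrad] = adamupdate(net,gradients,avgGrad,avgSqGrad,iteration);

function [loss,gradients] = modelLoss(net,dlX,dlT)
    % forward is the training-time counterpart of predict.
    Y = forward(net,dlX);
    % Differentiable stand-in for the 0-1 loss (an assumption of this sketch).
    loss = crossentropy(Y,dlT);
    gradients = dlgradient(loss,net.Learnables);
end
```

The call to dlgradient has to happen inside a function evaluated through dlfeval, which is why the loss and gradient computation live in modelLoss rather than in the script body. A full training loop would repeat the last two script statements over mini-batches while incrementing iteration.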
