predict function not working in custom training loop
I am building a custom training loop for a simple LSTM classification network because I need a custom loss function (specifically, 0-1 loss). I have followed a tutorial, but when I call the predict function within my custom loss function, I get the error: 'Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'.'
I can successfully train the network with trainNetwork, so I am wondering what trainNetwork is doing that I am missing. If I call predict after training with trainNetwork, it works, but not in the training loop.
My input is a 50x1 sequence that is classified as either 1 or 2, depending on whether its average is positive or negative.
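The toy problem above, together with the trainNetwork baseline that the question says works, can be reproduced with the sketch below. It is a hedged illustration, not the poster's code: the data generator, the dataset size, the label mapping (1 = negative mean, 2 = non-negative mean) and the training options are all assumptions. For trainNetwork, each sequence in the cell array is numFeatures-by-numTimeSteps, so a single-feature, 50-step sequence is stored as 1-by-50 rather than 50-by-1. The sketch also shows what trainNetwork does that the custom loop skips: it builds and returns a trained network object from the layer array, and predict is defined for that object.

```matlab
% Hypothetical toy data: 1-by-50 sequences labeled by the sign of their mean.
numObservations = 500;   % assumed dataset size
seqLength = 50;

XTrain = cell(numObservations, 1);
labels = zeros(numObservations, 1);
for i = 1:numObservations
    x = randn(1, seqLength) + 0.5*randn;   % random offset per sequence plus noise
    XTrain{i} = x;                         % numFeatures-by-numTimeSteps
    labels(i) = 1 + (mean(x) >= 0);        % assumed mapping: 1 = negative, 2 = non-negative
end
YTrain = categorical(labels, [1 2]);

layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(100, "OutputMode", "last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

% Assumed training options for a quick run.
options = trainingOptions("adam", "MaxEpochs", 10, "Verbose", false);
trainedNet = trainNetwork(XTrain, YTrain, layers, options);

% trainNetwork returns a network object, not the layer array, so predict works on it.
disp(class(trainedNet))
scores = predict(trainedNet, XTrain);   % numObservations-by-2 class scores
```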
My network is defined as follows:
numFeatures = 1; % one input feature per time step (each sequence has 50 time steps)
numHiddenUnits = 100;
numClasses = 2; % left/right decision at the end of the sequence
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,"OutputMode","last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
net = layers;
and my custom loss function is:
function [gradients,state,loss] = customGradients2Lay(net,dlX,Ylabel)
    [Y,state]=predict(net,dlX);
    loss=loss01(Y,Ylabel);
    gradients=dlgradient(loss,net.Learnables);
end
function loss = loss01(Y, T)
    if isequal(Y,T)
        loss = 0;
    else
        loss = 1;
    end
end
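As written, loss01 compares the softmax scores Y with the labels, so isequal is practically never true and the loss is always 1. Even when computed correctly, a 0-1 loss is piecewise constant, so its gradient is zero wherever it is defined, and a plain double result is not a traced dlarray that dlgradient can differentiate. A common approach is to train on a differentiable surrogate such as cross-entropy and report the 0-1 error only as a metric. The sketch below computes that metric from made-up scores, assuming classes along the rows and observations along the columns:

```matlab
% Example softmax outputs: rows are classes, columns are observations (assumed layout).
scores = [0.9 0.2 0.6;
          0.1 0.8 0.4];
labels = [1 2 2];                      % true class indices for the three observations

[~, predicted] = max(scores, [], 1);   % predicted class = row with the highest score
err01 = mean(predicted ~= labels);     % 1/3: only the third observation is misclassified
disp(err01)
```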
My training loop just uses a random dataset to test on and passes it to predict. The error again is: Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'.
I am also not sure why it is calling nnet.cnn. I even built a custom fully connected layer to try to get around this, and the error still referred to the nnet.cnn class.
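The nnet.cnn.layer.Layer in the error message is the class of a Deep Learning Toolbox layer array; it does not mean a convolutional network is involved. Because net = layers only copies the array, no network object is ever created, which would also explain why swapping in a custom layer did not change the message. A quick check, assuming Deep Learning Toolbox is installed:

```matlab
layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(100, "OutputMode", "last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

net = layers;                     % still a layer array; no network has been built
disp(class(net))                  % nnet.cnn.layer.Layer, the type named in the error

% A dlnetwork cannot contain output layers, so drop classificationLayer first.
dlnet = dlnetwork(layers(1:end-1));
disp(class(dlnet))                % dlnetwork, for which predict and forward are defined
```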
What am I missing?
Accepted Answer
James Gross
25 October 2022
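Hello,
The variable net in your code is the layer array itself (class nnet.cnn.layer.Layer), not a network, and predict is not defined for layer arrays. For a custom training loop, remove the classificationLayer from the array (a dlnetwork cannot contain output layers) and convert the layers into a dlnetwork object:
net = dlnetwork(layers);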
You should then be able to train and call predict on your network as desired. For examples of how to train using a custom training loop with a dlnetwork, you can refer to one of the following:
I hope this information helps!
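Putting the answer together, a minimal custom training loop for this classifier might look like the sketch below. It is an illustration under assumptions, not the answerer's code: the in-memory toy data (a channels-by-observations-by-time "CBT" array), the hyperparameters, Adam via adamupdate, and the helper modelLoss are all hypothetical choices. Cross-entropy is swapped in for the 0-1 loss because the 0-1 loss provides no useful gradient; the 0-1 error is only reported. It requires Deep Learning Toolbox (dlnetwork is available from R2019b).

```matlab
numFeatures = 1;
numHiddenUnits = 100;
numClasses = 2;

% dlnetwork objects cannot contain output layers, so classificationLayer is dropped.
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits, "OutputMode", "last")
    fullyConnectedLayer(numClasses)
    softmaxLayer];
net = dlnetwork(layers);

% Hypothetical toy data stored as channels x observations x time steps ("CBT").
numObservations = 512;
seqLength = 50;
XData = randn(numFeatures, numObservations, seqLength) + 0.5*randn(1, numObservations);
classIdx = 1 + (mean(XData, 3) >= 0);             % assumed mapping: 1 = negative, 2 = non-negative
TData = double([classIdx == 1; classIdx == 2]);   % one-hot targets, classes x observations

% Assumed hyperparameters.
numEpochs = 20;
miniBatchSize = 64;
learnRate = 0.01;
averageGrad = [];
averageSqGrad = [];
iteration = 0;

for epoch = 1:numEpochs
    order = randperm(numObservations);
    for b = 1:floor(numObservations/miniBatchSize)
        iteration = iteration + 1;
        batch = order((b-1)*miniBatchSize+1 : b*miniBatchSize);
        X = dlarray(XData(:, batch, :), "CBT");
        T = TData(:, batch);

        % dlgradient only works inside a function evaluated with dlfeval.
        [loss, gradients, err01] = dlfeval(@modelLoss, net, X, T);
        [net, averageGrad, averageSqGrad] = adamupdate(net, gradients, ...
            averageGrad, averageSqGrad, iteration, learnRate);
    end
    fprintf("Epoch %d: loss %.4f, 0-1 error %.3f\n", epoch, extractdata(loss), err01);
end

% After training, predict is defined for the dlnetwork.
scores = extractdata(predict(net, dlarray(XData, "CBT")));   % classes x observations
[~, predicted] = max(scores, [], 1);

function [loss, gradients, err01] = modelLoss(net, X, T)
    % forward runs the network in training mode.
    Y = forward(net, X);
    % Differentiable surrogate that drives the gradient.
    loss = crossentropy(Y, T);
    gradients = dlgradient(loss, net.Learnables);
    % 0-1 error, reported only; it contributes nothing to the gradient.
    [~, predicted] = max(extractdata(Y), [], 1);
    [~, actual] = max(T, [], 1);
    err01 = mean(predicted ~= actual);
end
```

The LSTM state returned by forward is deliberately not written back to net.State here: each sequence is independent, so every mini-batch starts from the network's initial hidden and cell state.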