
I got different outputs from the trained network

peng yu on 10 Jul 2024 at 13:06
Commented: peng yu on 14 Jul 2024 at 12:13
Hi all, I have trained an LSTM network and am using it to classify the test set. However, the outputs differ depending on whether I feed the test samples in one at a time through a for loop or pass them all at once as an array. Below is the code:
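% Xtest is an 81-by-1 array of test sequences.
% Case 1: classify one observation at a time in a for loop
for i = 1:numel(Xtest)
    testPred_single(i) = classify(LSTM_net, Xtest(i), 'SequenceLength', 'longest');
end
testPred_single = testPred_single(:);   % column vector, same shape as testPred
% Case 2: classify the whole array in a single call
testPred = classify(LSTM_net, Xtest, 'SequenceLength', 'longest');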
Below are some elements of the output variables testPred_single and testPred.
Could anyone explain what causes the difference between these two outputs? Thanks.
2 Comments
Aquatris on 10 Jul 2024 at 13:30
I am by no means an expert, but my understanding is that, by definition, LSTMs do not work well when the input data is not a sequence. When you give the inputs individually, you basically remove the sequence information, hence the different output.
peng yu on 11 Jul 2024 at 14:46
Thanks for your explanation. To test this, I also tried the MATLAB example (Sequence Classification Using 1-D Convolutions), and the same problem occurred when I used a for loop to feed in the test set.
openExample('nnet/SequenceClassificationUsing1DConvolutionsExample')
% My for loop: classify one validation sequence at a time
for i = 1:numel(XValidation)
    YPred_single(i) = classify(net, XValidation(i), ...
        MiniBatchSize=miniBatchSize, ...
        SequencePaddingDirection="left");
end
YPred_single = YPred_single';   % column vector, to match YPred
% MATLAB example code: classify all validation sequences at once
YPred = classify(net, XValidation, ...
    MiniBatchSize=miniBatchSize, ...
    SequencePaddingDirection="left");
Below are the details of the variables YPred and YPred_single.
It seems the problem is not specific to LSTMs; the 1-D CNN shows it too. So do you think the 1-D CNN also predicts badly when the input is a single sample?
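One way a 1-D CNN can be affected by the same thing: if the network averages over time (I believe the 1-D convolution example ends in a global average pooling layer, but treat that as an assumption), then extra zero time steps added by padding change the average. The Python sketch below shows only that arithmetic. The function name and the activation values are hypothetical, it assumes padded steps are not masked out, and it ignores that a convolution's bias can make padded steps nonzero in a real network.

```python
def global_average_pool(sequence):
    """Average each feature channel over all time steps (no masking of padding)."""
    num_steps = len(sequence)
    num_channels = len(sequence[0])
    return [sum(step[ch] for step in sequence) / num_steps
            for ch in range(num_channels)]


# Hypothetical activations for a 2-step, 2-channel sequence.
activations = [[2.0, 1.0], [4.0, 3.0]]
# The same sequence after left-padding with two all-zero time steps.
padded = [[0.0, 0.0], [0.0, 0.0]] + activations

print(global_average_pool(activations))  # [3.0, 2.0]
print(global_average_pool(padded))       # [1.5, 1.0]
```

The pooled features, and therefore the class scores computed from them, depend on how many padded steps the sequence happened to receive.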


Accepted Answer

Antoni Woss on 12 Jul 2024 at 13:05
Edited: Antoni Woss on 12 Jul 2024 at 13:06
The differences in the output come from the preprocessing applied to your data in the call to minibatchpredict or classify, as in the referenced examples. Specifically, SequencePaddingDirection="left" pads the sequences in each mini-batch (of up to MiniBatchSize observations) with leading zeros, so that every observation within the mini-batch has the same total number of time steps. You can find more information about sequence padding on this documentation page: https://uk.mathworks.com/help/deeplearning/ug/long-short-term-memory-networks.html#mw_81a7b85b-51dc-4bd7-9bb9-215f473a956f
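To make left padding concrete, here is a minimal Python sketch of padding one mini-batch. It illustrates the idea only and is not MATLAB's internal implementation; the function name left_pad_batch and the list-of-time-steps layout are hypothetical choices for illustration.

```python
def left_pad_batch(batch):
    """Left-pad every sequence in a mini-batch with all-zero time steps.

    Each sequence is a list of time steps; each time step is a list of
    feature values. All sequences are assumed to share the same feature count.
    """
    longest = max(len(seq) for seq in batch)
    padded = []
    for seq in batch:
        num_features = len(seq[0])
        missing = longest - len(seq)
        zeros = [[0.0] * num_features for _ in range(missing)]
        # Zeros go in front ("left"), so the real data ends at the last step.
        padded.append(zeros + [list(step) for step in seq])
    return padded


batch = [
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],                    # 2 time steps
    [[7.0, 8.0, 9.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],   # 3 time steps
]
padded = left_pad_batch(batch)
print(padded[0])  # [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
```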
As a concrete example, the first two entries of XTest have different numbers of time steps.
XTest(1:2)
ans =
2×1 cell array
{127×3 double}
{180×3 double}
So running minibatchpredict with MiniBatchSize=2 and SequencePaddingDirection="left" prepends a 53-by-3 block of zeros to the first entry of XTest, so that both observations are 180-by-3.
Running minibatchpredict with MiniBatchSize=1 does no padding and passes each of the two sequences through the network separately. You would therefore expect a difference in the output for the first observation between the two cases, but not for the second: the second observation is never padded, because with MiniBatchSize=1 it is alone in its mini-batch and with MiniBatchSize=2 it is the longest sequence.
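The padding each observation receives can be counted with a short Python sketch. It assumes mini-batches are formed from consecutive observations in order and that each mini-batch is padded to its own longest sequence (the "longest" sequence-length behavior); the helper name padding_per_observation is hypothetical.

```python
def padding_per_observation(lengths, mini_batch_size):
    """Return how many zero time steps each observation receives when
    sequences are grouped into consecutive mini-batches and every
    mini-batch is padded to its own longest sequence."""
    pads = []
    for start in range(0, len(lengths), mini_batch_size):
        chunk = lengths[start:start + mini_batch_size]
        longest = max(chunk)
        pads.extend(longest - n for n in chunk)
    return pads


lengths = [127, 180]  # time steps in the first two entries of XTest
print(padding_per_observation(lengths, 1))  # [0, 0]
print(padding_per_observation(lengths, 2))  # [53, 0]
```

Only the first observation's padding depends on the mini-batch size, which matches which row of the scores is expected to change.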
scoresMiniBatchSize_1 = minibatchpredict(net,XTest,SequencePaddingDirection="left",MiniBatchSize=1);
scoresMiniBatchSize_2 = minibatchpredict(net,XTest,SequencePaddingDirection="left",MiniBatchSize=2);
scoresMiniBatchSize_1(1:2,:)
ans =
2×4 single matrix
0.0000 0.8725 0.0000 0.1274
1.0000 0.0000 0.0000 0.0000
scoresMiniBatchSize_2(1:2,:)
ans =
2×4 single matrix
0.0000 0.8755 0.0006 0.1239
1.0000 0.0000 0.0000 0.0000
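Why should leading zeros change the scores at all? A trained recurrent layer generally has nonzero biases, so its state moves even on all-zero time steps, and the real data then starts from a different state. The Python sketch below runs a one-unit LSTM with made-up weights (hypothetical values, not taken from any trained network) on a sequence with and without two leading zero steps.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical (input weight, recurrent weight, bias) per gate of a one-unit LSTM.
PARAMS = {
    "input":  (0.8, 0.3, 0.2),
    "forget": (0.5, 0.4, 1.0),
    "cell":   (1.2, 0.6, -0.3),
    "output": (0.7, 0.2, 0.1),
}


def gate(name, x, h):
    """Pre-activation of one gate for input x and previous hidden state h."""
    w, r, b = PARAMS[name]
    return w * x + r * h + b


def lstm_last_output(sequence):
    """Run the scalar LSTM over a sequence and return the final hidden state."""
    h, c = 0.0, 0.0
    for x in sequence:
        i = sigmoid(gate("input", x, h))
        f = sigmoid(gate("forget", x, h))
        g = math.tanh(gate("cell", x, h))
        o = sigmoid(gate("output", x, h))
        c = f * c + i * g
        h = o * math.tanh(c)
    return h


seq = [0.5, -1.0, 0.8]
unpadded = lstm_last_output(seq)
left_padded = lstm_last_output([0.0, 0.0] + seq)
print(unpadded, left_padded)
```

Because the zero steps pass through the biases, the two final outputs differ, which is the same kind of small shift seen in the first row of the scores.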
1 Comment
peng yu on 14 Jul 2024 at 12:13
Dear Antoni, thanks a lot for your useful response; it is really helpful. I tried my model after manually padding the training samples to the same length, and this time the difference in the outputs disappeared. Thank you very much again!
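The fix described above, padding all samples to one common length up front, can be sketched in Python as follows. Once every sequence already has the same length, per-mini-batch padding has nothing left to add, so the result no longer depends on MiniBatchSize. The helper pad_all_left and its target_length parameter are hypothetical names for illustration.

```python
def pad_all_left(sequences, target_length=None):
    """Left-pad every sequence to one common length, once, before inference.

    Each sequence is a list of time steps (each a list of features). If
    target_length is None, the longest sequence in the whole set is used.
    """
    if target_length is None:
        target_length = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        if len(seq) > target_length:
            raise ValueError("sequence is longer than target_length")
        zeros = [[0.0] * len(seq[0]) for _ in range(target_length - len(seq))]
        padded.append(zeros + [list(step) for step in seq])
    return padded


data = [[[1.0]], [[2.0], [3.0]], [[4.0], [5.0], [6.0]]]
fixed = pad_all_left(data)
print([len(seq) for seq in fixed])  # [3, 3, 3]
print(fixed[0])                     # [[0.0], [0.0], [1.0]]
```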



Categories

Find more on Sequence and Numeric Feature Data Workflows in Help Center and File Exchange

Release

R2022a
