
What does function predict() in Deep Learning Toolbox do?

Song Decn, 8 May 2021
Edited: Vidip, 21 Feb 2024
Hi, I followed this example and made a small modification: instead of calling predict(), I call predictAndUpdateState() to predict the targets one at a time.
This gives a much worse prediction result (brown line) than predict() (yellow line).
Can anyone explain this?
The only difference is:
% opt1: use only the feature variables as input
net = resetState(net);
YPred = [];
for i = 1:numel(XTest)
[net, temp] = predictAndUpdateState(net, XTest(:,i), 'ExecutionEnvironment', 'cpu');
YPred(:,i) = cell2mat(temp);
end
y1 = YPred;
Full code:
[~,~,data] = xlsread('ET_1.xlsx');
data_mat = cell2mat(data);
XTrain = (data_mat(:,4:8))';
XTrain = num2cell(XTrain,1);
YTrain = (data_mat(:,3))';
YTrain = num2cell(YTrain,1);
%%Define Network Architecture
featureDimension = size(XTrain{1},1);
numResponses = size(YTrain{1},1);
numHiddenUnits = 500;
layers = [ ...
sequenceInputLayer(featureDimension)
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(500) %%50
dropoutLayer(0.1) %%0.5
fullyConnectedLayer(numResponses)
regressionLayer
];
maxepochs = 500;
miniBatchSize = 1;
options = trainingOptions('adam', ... %%adam
'MaxEpochs',maxepochs, ...
'GradientThreshold',1, ...
'InitialLearnRate',0.005, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',125, ...
'LearnRateDropFactor',0.2, ...
'Verbose',0, ...
'Plots','training-progress');
%%Train the Network
net = trainNetwork(XTrain,YTrain,layers,options);
%% Test the Network
[~,~,data] = xlsread('ET_2.xlsx');
data_mat = cell2mat(data);
XTest = (data_mat(:,4:8))'; XTest = num2cell(XTest,1);
YTest = (data_mat(:,3))'; YTest = num2cell(YTest,1);
% opt1: use only the feature variables as input
net = resetState(net);
YPred = [];
for i = 1:numel(XTest)
[net, temp] = predictAndUpdateState(net, XTest(:,i), 'ExecutionEnvironment', 'cpu');
YPred(:,i) = cell2mat(temp);
end
y1 = YPred;
% opt2: predict()
net = resetState(net);
YPred = predict(net, XTest);
y2 = (cell2mat(YPred)); %have to transpose as plot plots columns
%%
figure; hold all
yRef = (cell2mat(YTest)');
plot(yRef, '-o')
plot(y1, '-x')
plot(y2, '-s')
1 comment
Song Decn, 10 May 2021
% Opt1:
% yTrain = predict(net, xTrainStandardized);
% yTrain = cell2mat(yTrain);
% Opt2:
% yTrain = [];
% for i = 1:numel(xTrainStandardized)
% [net, tmp] = predictAndUpdateState(net, xTrainStandardized(i));
% yTrain(i) = cell2mat(tmp);
% end
% Opt3:
[net, tmp] = predictAndUpdateState(net, xTrainStandardized);
yTrain = cell2mat(tmp);
Why do these three ways of calculating the responses give different values?


Answers (1)

Vidip, 21 Feb 2024
Edited: Vidip, 21 Feb 2024
You get worse results with ‘predictAndUpdateState’ in a loop than with ‘predict’ because of how the LSTM network's state is managed between predictions. The predict function treats each sequence as independent and starts each one from the initial LSTM state, which is appropriate when your test sequences are not temporally related. However, when you call ‘predictAndUpdateState’ in a loop without resetting the state after each prediction, the LSTM network's internal state carries over from one prediction to the next.
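As a rough illustration of the state carry-over described above, the plain MATLAB sketch below uses a hypothetical scalar recurrent cell (not an LSTM and not the toolbox implementation) with made-up weights w and u. Feeding the same input three times gives identical outputs when each input starts from a zero state, but different outputs when the state is carried from one step to the next.

```matlab
% Hypothetical scalar recurrent cell: h_t = tanh(w*x_t + u*h_{t-1}), y_t = h_t.
% Illustrates state carry-over only; it is not an LSTM.
w = 0.8;           % made-up input weight
u = 0.5;           % made-up recurrent weight
x = [1 1 1];       % the same input presented three times

% Stateless: every input is its own length-1 sequence starting from h = 0,
% which is analogous to how predict() handles independent sequences.
yIndependent = tanh(w * x + u * 0);

% Stateful: h is carried from one call to the next,
% which is analogous to the predictAndUpdateState loop without resetState.
h = 0;
yCarried = zeros(size(x));
for t = 1:numel(x)
    h = tanh(w * x(t) + u * h);
    yCarried(t) = h;
end

disp(yIndependent)   % all three values equal tanh(0.8)
disp(yCarried)       % first value equals tanh(0.8); later values differ
```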
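This means that the network's prediction for each data point is influenced by all the previous data points, which is not suitable if the sequences in ‘XTest’ are supposed to be independent. In the code above they are: num2cell(XTest,1) turns every time step into its own sequence of length 1, and the training data was built the same way, so the network was only ever trained to predict from the initial state. The accumulation of state information from unrelated sequences can lead to inaccurate predictions, because the network is incorrectly using historical context from separate sequences to make its predictions.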
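A minimal sketch, assuming Deep Learning Toolbox with the trainNetwork-era API used in the question, of how the loop can be made to behave like predict when each cell is an independent length-1 sequence: call resetState before every predictAndUpdateState call. The data, layer sizes, and variable names (yPredict, yCarry, yReset) are hypothetical toy values, not the asker's data, and the network is trained only briefly so the comparison runs quickly.

```matlab
% Toy data shaped like the question's: one cell per time step, each a
% numFeatures-by-1 column, i.e. independent sequences of length 1.
% Requires Deep Learning Toolbox; values are random, not real data.
rng(0);
numFeatures = 5;
numSteps = 40;
X = num2cell(rand(numFeatures, numSteps), 1)';   % 40x1 cell, each 5x1
Y = num2cell(rand(1, numSteps), 1)';             % 40x1 cell, each 1x1

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(8, 'OutputMode', 'sequence')
    fullyConnectedLayer(1)
    regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 2, 'Verbose', false, ...
    'ExecutionEnvironment', 'cpu');
net = trainNetwork(X, Y, layers, opts);

% Reference: predict() handles each cell as its own sequence.
yPredict = cell2mat(predict(net, X, 'ExecutionEnvironment', 'cpu'));

% Loop as in the question: the state carries over between cells.
netCarry = resetState(net);
yCarry = zeros(numSteps, 1);
for i = 1:numSteps
    [netCarry, out] = predictAndUpdateState(netCarry, X(i), ...
        'ExecutionEnvironment', 'cpu');
    yCarry(i) = out{1};
end

% Loop with a reset before every cell, so each length-1 sequence starts
% from the initial state, as it did during training.
yReset = zeros(numSteps, 1);
for i = 1:numSteps
    netStep = resetState(net);
    [~, out] = predictAndUpdateState(netStep, X(i), ...
        'ExecutionEnvironment', 'cpu');
    yReset(i) = out{1};
end

fprintf('max |predict - reset loop| = %g\n', max(abs(yPredict(:) - yReset)));
fprintf('max |predict - carry loop| = %g\n', max(abs(yPredict(:) - yCarry)));
```

If the test data is really one continuous time series, the usual alternative is to keep it as a single numFeatures-by-numSteps sequence for both training and prediction, so that carrying the state forward is something the network has actually learned.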
For further information, refer to the documentation link below:
