Improve LSTM traffic forecast with test data

arash rad, 8 October 2022
Hi everyone,
I am using LSTM code from GitHub. It runs, but it can't predict the traffic flow data at each time step. So I added the for loop and if statement below: if a prediction is not equal to the test value, the test value is used in place of the predicted value.
for i = 1:Nt
    % compare element by element and assign with '=' ('==' only compares)
    if YPred(i) ~= YTest(i)
        YPred(i) = YTest(i);
    end
end
YPred = round(YPred);
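In MATLAB, `==` is a comparison and `=` is assignment, so a statement like `YPred(i) == YTest(i);` computes a logical value and discards it without changing `YPred`. Note also that overwriting every mismatched forecast with the observed value simply makes `YPred` identical to `YTest`: the error becomes zero by construction, not because the LSTM predicts better. A minimal vectorized sketch of the same replacement, using made-up numbers (illustrative only, not from the traffic data set):

```matlab
% Illustrative values only (not from zafar_queue.xlsx)
YPred = [12.4 15.1 9.8 20.2];
YTest = [12   16   10  20  ];

% '==' only compares: this produces a logical value and changes nothing
YPred(1) == YTest(1);

% Vectorized equivalent of the corrected loop
YPred = round(YPred);               % [12 15 10 20]
mismatch = YPred ~= YTest;          % logical mask of differing elements
YPred(mismatch) = YTest(mismatch);  % overwrite only the mismatches

% After the overwrite, the "forecast" is just the test data
disp(isequal(YPred, YTest))         % displays 1
```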
Thank you for your help.
Here is the whole code:
clc;clear all;close all
warning off
flow_data = readtable('zafar_queue.xlsx');
Y = flow_data.nVehContrib;
data = Y';
%
%about 4 hours and 20 minutes for data training
% about 40 minutes for test
numTimeStepsTrain = floor(0.95*numel(data));
dataTrain = data(1:numTimeStepsTrain);
dataTest = data(numTimeStepsTrain+1:end);
% Normalize(Training Data Set)
mu = mean(dataTrain);
sig = std(dataTrain);
dataTrainStandardized = (dataTrain - mu) / sig;
XTrain = dataTrainStandardized(1:end-1);
YTrain = dataTrainStandardized(2:end);
%LSTM Net Architecture Def
numFeatures = 1;
numResponses = 1;
numHiddenUnits = 200;
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits)
    fullyConnectedLayer(numResponses)
    regressionLayer];
options = trainingOptions('adam', ...
    'MaxEpochs',500, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.005, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',150, ...
    'LearnRateDropFactor',0.25, ...
    'Verbose',1, ...
    'Plots','training-progress');
%
% Train LSTM Net
net = trainNetwork(XTrain,YTrain,layers,options);
% Standardize the test data with the training mean and standard deviation
dataTestStandardized = (dataTest - mu) / sig;
XTest = dataTestStandardized(1:end);
net = predictAndUpdateState(net,XTrain);
[net,YPred] = predictAndUpdateState(net,YTrain(end));
%
% Predict as long as the test period
numTimeStepsTest = numel(XTest);
for i = 2:numTimeStepsTest
    % closed loop: each step is fed the previous prediction
    [net,YPred(:,i)] = predictAndUpdateState(net,YPred(:,i-1),'ExecutionEnvironment','cpu');
end
% RMSE calculation of test data set
YTest = dataTest(1:end);
YTest = (YTest - mu) / sig;
rmse = sqrt(mean((YPred-YTest).^2))
% Denormalize Data
YPred = sig*YPred + mu;
YTest = sig*YTest + mu;
% x-axis: start time of each one-minute counting period
x_data = seconds(flow_data.begin);
x_train = x_data(1:numTimeStepsTrain);
x_train = x_train';
x_pred = x_data(numTimeStepsTrain:numTimeStepsTrain+numTimeStepsTest);
YPred = round(YPred);
Nt = length(YTest);
for i = 1:Nt
    % overwrite a mismatched forecast with the observed test value
    if YPred(i) ~= YTest(i)
        YPred(i) = YTest(i);
    end
end
YPred = round(YPred);
% Train + Predict Plot
figure
plot(x_train(1:end),dataTrain(1:end))
hold on
plot(x_pred,[data(numTimeStepsTrain) YPred],'.-')
% hold off
xlabel("time")
ylabel("Flow")
title("Forecast")
legend(["Observed" "Forecast"])
% RMSE Plot : Test + Predict Plot
figure
subplot(2,1,1)
plot(YTest)
hold on
plot(YPred,'.-')
hold off
legend(["Observed" "Forecast"])
ylabel("Traffic flow")
title("Forecast")
subplot(2,1,2)
stem(YPred - YTest)
xlabel("period of time")
ylabel("Error")
title("RMSE (standardized) = " + rmse)
% Train + Test + Predict Plot
figure
plot(x_data,Y)
hold on
plot(x_pred,[data(numTimeStepsTrain) YPred],'.-')
hold off
xlabel("one-min period")
ylabel("Traffic Flow")
title("Compare Data")
legend(["Raw" "Forecast"])
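If the goal is for the network to use the observed test values while forecasting, a common alternative to overwriting the predictions afterwards is to feed the observed value from the previous time step into `predictAndUpdateState` instead of the previous prediction (open-loop forecasting, as opposed to the closed-loop forecasting in the code above; MathWorks' time-series forecasting documentation example describes both). The sketch below assumes a function handle `stepFcn` with the same calling form as `[net, y] = predictAndUpdateState(net, x)`. The toy handle used here (`y = 0.5*x`, state unchanged) and all the numbers are hypothetical, only there so the loops run without a trained network.

```matlab
% Hypothetical stand-in for one LSTM step, same form as
% [net, y] = predictAndUpdateState(net, x). With a trained network use e.g.
% stepFcn = @(net, x) predictAndUpdateState(net, x, 'ExecutionEnvironment', 'cpu');
stepFcn = @(state, x) deal(state, 0.5 * x);

state  = [];          % plays the role of 'net' after priming on XTrain
x0     = 2;           % plays the role of YTrain(end)
XTest  = [4 6 8 10];  % observed (standardized) test values, illustrative only
nSteps = numel(XTest);

% Closed loop: every step after the first consumes the previous forecast
stateC = state;
YPredClosed = zeros(1, nSteps);
[stateC, YPredClosed(1)] = stepFcn(stateC, x0);
for i = 2:nSteps
    [stateC, YPredClosed(i)] = stepFcn(stateC, YPredClosed(i-1));
end

% Open loop: every step after the first consumes the observed previous value
stateO = state;
YPredOpen = zeros(1, nSteps);
[stateO, YPredOpen(1)] = stepFcn(stateO, x0);
for i = 2:nSteps
    [stateO, YPredOpen(i)] = stepFcn(stateO, XTest(i-1));
end
```

In the posted script this corresponds to changing the input in the prediction loop from `YPred(:,i-1)` to `XTest(:,i-1)`. Keep in mind that the RMSE then measures one-step-ahead error, which is an easier task than forecasting the whole test period from the training data alone.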
