Initial State Dynamical System LSTM Network

Asked by Michael Hesse on 18 Nov 2020
Commented: Michael Hesse on 19 Nov 2020
%% - cleanup
clear;
close all;
clc;
%% - data
t = linspace(0, 5, 1000);
odefcn = @(t, x) [x(2, :); 10*sin(x(1, :))-x(2, :)];
x0 = [pi/2, 0]';
[~, x] = ode45(odefcn, t, x0);
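x = x';             % 2-by-1000: one row per state, one column per time step
X = x(:, 1:end-1);  % inputs: states x_k
Y = x(:, 2:end);    % targets: next states x_{k+1} (one-step-ahead)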
%% - define and train lstm network
numFeatures = 2;
numResponses = 2;
numHiddenUnits = 200;
layers = [sequenceInputLayer(numFeatures);
    lstmLayer(numHiddenUnits);
    fullyConnectedLayer(numResponses);
    regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 100, 'Plots', 'training-progress');
net = trainNetwork(X, Y, layers, opts);
%% - prediction
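net = resetState(net);
xpred = zeros(numFeatures, length(t));   % preallocate; column i is the state at t(i)
xpred(:, 1) = x0;
for i = 1 : length(t)-1
    % closed loop: feed the previous prediction back in as the next input
    [net, xpred(:, i+1)] = predictAndUpdateState(net, xpred(:, i));
end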
%% - plotting
figure(1);
plot(t, x);
hold on;
grid on;
plot(t, xpred, '--');
This is example code in which I want to predict the trajectory of a pendulum with an LSTM neural network. How can I provide the initial state x0 to the network? As the figure shows, the second predicted state jumps directly from x0 to the state [0, 0]'. Why does this happen?
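The prediction loop above is a closed-loop (autoregressive) rollout: the known initial state x0 is the first input, and each prediction is fed back as the next input. The sketch below isolates that loop structure without Deep Learning Toolbox. As an assumption for illustration, the trained LSTM is replaced by a hypothetical one-step model, `stepModel`, that integrates the same ODE over one time step with `ode45`; `odeOpts` and the tight tolerances are also choices made only for this sketch. With an exact one-step map the rollout should track the reference trajectory closely, which suggests the loop indexing itself is not what produces the jump; it does not explain the LSTM's behaviour.

```matlab
% Closed-loop rollout sketch. ASSUMPTION: stepModel is a hypothetical
% stand-in for the trained LSTM that integrates the ODE over one time step.
odefcn  = @(t, x) [x(2, :); 10*sin(x(1, :)) - x(2, :)];
t       = linspace(0, 5, 1000);
x0      = [pi/2, 0]';
odeOpts = odeset('RelTol', 1e-8, 'AbsTol', 1e-10);

% Reference trajectory from a single ode45 call, stored as 2-by-N
[~, xref] = ode45(odefcn, t, x0, odeOpts);
xref = xref';

% Hypothetical one-step model: maps the state at t(k) to the state at t(k+1).
% Called with one output, ode45 returns a solution struct that deval evaluates.
stepModel = @(xk, k) deval(ode45(odefcn, [t(k), t(k+1)], xk, odeOpts), t(k+1));

% Rollout: seed with the known x0, then feed each output back as the next input
xpred = zeros(2, numel(t));
xpred(:, 1) = x0;
for k = 1 : numel(t)-1
    xpred(:, k+1) = stepModel(xpred(:, k), k);
end

maxErr = max(abs(xpred - xref), [], 'all');
fprintf('max deviation from the reference trajectory: %.2e\n', maxErr);
```

In the LSTM version, `predictAndUpdateState` plays the role of `stepModel`, with the difference that it also carries the LSTM's hidden and cell state from one call to the next.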
Comment by Michael Hesse on 19 Nov 2020:
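Here is a possible workaround: instead of learning the next state x(k+1), the network learns the difference to the next state, x(k+1) - x(k), and the prediction loop adds each predicted difference to the current state. This version also enables z-score normalization of the input ('Normalization', 'zscore' in sequenceInputLayer).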
%% - cleanup
clear;
close all;
clc;
%% - data
t = linspace(0, 5, 1000);
odefcn = @(t, x) [x(2, :); 10*sin(x(1, :))-x(2, :)];
x0 = [pi/2, 0]';
[~, x] = ode45(odefcn, t, x0);
x = x';
X = x(:, 1:end-1);                  % inputs: states x_k
Y = x(:, 2:end) - x(:, 1:end-1);    % targets: increments x_{k+1} - x_k
%% - define and train lstm network
numFeatures = 2;
numResponses = 2;
numHiddenUnits = 200;
layers = [sequenceInputLayer(numFeatures, 'Normalization', 'zscore');
    lstmLayer(numHiddenUnits);
    fullyConnectedLayer(numResponses);
    regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 100, 'Plots', 'training-progress');
net = trainNetwork(X, Y, layers, opts);
%% - prediction
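net = resetState(net);                   % start the rollout from the initial LSTM state
xpred = zeros(numFeatures, length(t));   % preallocate; column i is the state at t(i)
xpred(:, 1) = x0;
for i = 1 : length(t)-1
    [net, dxpred] = predictAndUpdateState(net, xpred(:, i));
    xpred(:, i+1) = xpred(:, i) + dxpred;   % accumulate the predicted increment
end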
%% - plotting
figure(1);
plot(t, x);
hold on;
grid on;
plot(t, xpred, '--');
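The workaround changes the regression target from the next state x_{k+1} to the increment dx_k = x_{k+1} - x_k, and the prediction loop recovers the trajectory by accumulating increments, x_{k+1} = x_k + dx_k. The sketch below checks that bookkeeping without training a network. As an assumption, the exact increments from the ode45 data stand in for the LSTM's predictions, and a hypothetical constant per-step error `b` is added to show how errors in the predicted increments accumulate over the rollout.

```matlab
% Increment-target sketch (no network). ASSUMPTION: exact increments from
% the ode45 data stand in for the LSTM's predicted increments.
t = linspace(0, 5, 1000);
odefcn = @(t, x) [x(2, :); 10*sin(x(1, :)) - x(2, :)];
x0 = [pi/2, 0]';
[~, x] = ode45(odefcn, t, x0);
x = x';                                   % 2-by-1000

% Targets as in the workaround: dx_k = x_{k+1} - x_k (same values as diff(x, 1, 2))
dX = x(:, 2:end) - x(:, 1:end-1);

% Rollout bookkeeping: x_{k+1} = x_k + dx_k, accumulated from the known x0
xrec   = [x0, x0 + cumsum(dX, 2)];
recErr = max(abs(xrec - x), [], 'all');

% Hypothetical constant error b in every predicted increment accumulates
b          = 1e-4;
xbiased    = [x0, x0 + cumsum(dX + b, 2)];
finalDrift = xbiased(:, end) - xrec(:, end);   % about (N-1)*b per state

fprintf('reconstruction error: %.2e\n', recErr);
fprintf('final drift with bias b: [%.4f, %.4f]\n', finalDrift);
```

With exact increments the trajectory is recovered up to rounding error. A constant per-step error b produces a drift of about (N-1)*b after N-1 steps (roughly 0.1 here), so even small biases in the learned increments grow with the prediction horizon.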


Answers (0)

Category: Deep Learning Toolbox

Release: R2020b
