LSTMに対するシーケンスを含む複数データの入力の方法に関して
조회 수: 16 (최근 30일)
이전 댓글 표시
現在、深層学習を使用したsequence-to-sequenceの回帰を行っています。しかしシーケンスデータだけでは学習の情報量としては不足しており、入力に新たなシーケンスでない特徴データを追加しようとしています。
そのためにLSTMに対してシーケンスデータと、その他のデータを同時に入力したいのですが、セル配列として入力するとエラーが起きてしまいます。
具体的には
Layers = [ ...
sequenceInputLayer(3)
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(1)
myRegressionLayer('mae')
];
options = trainingOptions('adam', ...
'MaxEpochs',40,...
'MiniBatchSize',32,...
'GradientThreshold',1,...
'InitialLearnRate',1e-2, ...
'Verbose',false, ...
'Plots', 'training-progress');
Train1:n*tのdouble配列のシーケンスデータ、特徴次元1(t=10)
Train2:n*2のdouble配列のデータ、特徴次元2
XTrain:Train1とTrain2を合体させたもの、特徴次元3
YTrain:n*tのセル配列のシーケンスデータ(t=10)
ここでTrain1とTrain2を一つのセル配列にまとめて格納するため
XTrain = cell(n, 1);
for i = 1:n
XTrain{i, 1} = {Train1(i, :), Train2(i, :)};
end
[net, ~] = trainNetwork(XTrain, YTrain, Layers, options);
とすると、「無効な学習データです。予測子はシーケンスの N 行 1 列の cell 配列でなければなりません。ここで N はシーケンスの数です。すべてのシーケンスは同じ特徴次元と、少なくとも 1 つのタイム ステップをもたなければなりません。」というエラーが起こります。
XTrainのサイズはN*1となってはいますが、このエラーメッセージの原因としてはセル配列XTrainの中身がセル配列になっているのが原因らしく、
かといって、次のようにdouble配列のまま入力すると
Train = cell(n, 1);
for i = 1:n
XTrain{i, 1} = [Train1(i, :)'; Train2(i, :)'];
end
[net, ~] = trainNetwork(XTrain, YTrain, Layers, options);
とすると、当然ながら
「トレーニング シーケンスの特徴次元は 12 ですが、入力層には特徴次元 3 のシーケンスが必要です。」
というエラーが起き、特徴次元数がシーケンスの長さ+2になってしまい、特徴次元がシーケンス長に対応してしまううえ、Train1とTrain2が混ざって一つのシーケンスになってしまいます。
少々稚拙な書き方になってしまっていると思いますが、何卒ご教授いただければ幸いです。
댓글 수: 0
채택된 답변
Ayush Aniket
2024년 6월 17일
The input format required by the LSTM network in MATLAB for dataset of sequences is a Nx1 cell array where each element is a c-by-s matrix, where c is the number of features of the sequence and s is the sequence length. Refer to the following document link to read about various input formats: https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html#mw_36a68d96-8505-4b8d-b338-44e1efa9cc5e
In your first approach, XTrain becomes a Nx1 cell array wherein each element is again a cell array having 2 elements- a 1x10 matrix and 1x2 matrix - which is not the correct format.
And in your second approach, XTrain becomes a Nx1 cell array having each element as 1x12 matrix - which is in correct format but with number of features being 1 and the sequence length as 12. However, the layers for the network expect a sequence with the number of features as 3.
For your task, where Train1 is your sequence data and Train2 is your non-sequence data, you cannot directly concatenate these as elements within a cell array. LSTM networks in MATLAB expect the input data to be in a specific format, and combining different types of data (sequence and non-sequence) requires careful preprocessing.
Instead, you need to modify the network architecture to accept multiple inputs or transform your non-sequence data into a sequence format that the network can process alongside your existing sequence data.
One way to handle this is by repeating your non-sequence features Train2 at each time step to match the sequence length of Train1 as shown below:
n = size(Train1, 1); % Assuming Train1 is n by t
t = size(Train1, 2); % Number of timesteps
featureDim = size(Train2, 2); % Number of non-sequence features
% Preallocate XTrain
XTrain = cell(n, 1);
% Concatenate Train1 and repeated Train2 for each sample
for i = 1:n
repeatedTrain2 = repmat(Train2(i, :), t, 1); % Repeat Train2 to match timesteps
XTrain{i, 1} = [Train1(i, :)', repeatedTrain2]'; % Concatenate along the second dimension
end
This way, your LSTM network can accept this as part of the sequence input. However, this does not capture the information that you are trying to learn correctly as Train2 is not a sequence data.
The correct approach would be to modify your network architecture to accept multiple inputs, one for the sequence data and another for the non-sequence data. However, this approach requires a custom training loop since trainNetwork does not directly support multiple inputs.
Note - trainNetwork function is not recommended anymore. You can use the trainnet function instead. To read about the input formats for sequence datsets in trainnet function, refer the following link:
추가 답변 (0개)
참고 항목
카테고리
Help Center 및 File Exchange에서 ビッグ データの処理에 대해 자세히 알아보기
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!