How to retrieve the cell/hidden state of an LSTM layer during training

Hi everyone,
As the title says, I'm trying to extract the cell and hidden states from an LSTM layer after training. Unfortunately, I haven't found a solution yet.
Does anyone know how this works, or whether it is even possible?
Thanks for any advice!

Answers (5)

Da-Ting Lin, 11 Feb 2020
I also have this question. Hopefully it may be included in an upcoming release?

Haoyuan Ma, 16 Mar 2020
I have this question too...
I have tried many times before seeing this page.

Giuseppe Dell'Aversana, 16 Apr 2020
I also have this question. Maybe someone has the answer now?

Yildirim Kocoglu, 10 Jan 2021
It's a little late, but I had the same question and came across this: https://www.mathworks.com/help/ident/ug/use-lstm-for-linear-system-identification.html
I haven't tried it yet, but please read it carefully, as it may help.
Read the section "Set Network Initial State".
It says: As the network performs estimation using a step input from 0 to 1, the states of the LSTM network (cell and hidden states of the LSTM layers) drift toward the correct initial condition. To visualize this, extract the cell and hidden state of the network at every time step using the predictAndUpdateState function.
Here is some code from the documentation which you can try to modify to achieve what you need:
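% From the "Use LSTM for Linear System Identification" example: fourthOrderNet is
% the trained network, stepSignal the step input, and time the time vector.
stepMarker = time <= 2;                                   % time steps to simulate
numSteps = sum(stepMarker);
numHiddenUnits = fourthOrderNet.Layers(2).NumHiddenUnits; % 200 LSTM units in the example
yhat = zeros(numSteps,1);
hiddenState = zeros(numSteps,numHiddenUnits);             % one row per time step
cellState = zeros(numSteps,numHiddenUnits);
for ntime = 1:numSteps
    % Predict one step and keep the updated network so the states carry over
    [fourthOrderNet,yhat(ntime)] = predictAndUpdateState(fourthOrderNet,stepSignal(ntime)');
    % Layer 2 is the LSTM layer; read its states after this time step
    hiddenState(ntime,:) = fourthOrderNet.Layers(2).HiddenState;
    cellState(ntime,:) = fourthOrderNet.Layers(2).CellState;
end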
If you have multiple batches (sequences), you can loop over them and run each one through your trained network, one at a time (for i = 1:batch_size). If you call net = resetState(net) (assuming you saved your trained network as net) at the start of each iteration, the states are reset to their initial values, which are zeros unless you specified them beforehand. These are the same initial states used during training, so with the code above you should be able to see the hidden and cell states at every time step of each sequence.
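A minimal, self-contained sketch of that per-sequence loop, assuming the SeriesNetwork workflow (trainNetwork, resetState, predictAndUpdateState) that the documentation example uses. The synthetic data, the target, the 8-unit layer, and the names allHidden/allCell are hypothetical; as in the snippet above, the LSTM is assumed to be layer 2 of the network.

```matlab
% Hypothetical setup: 20 random sequences, 1 feature x 50 time steps each.
rng(0);
numHiddenUnits = 8;
numSeq = 20;
seqLen = 50;
XTrain = cell(numSeq,1);
YTrain = cell(numSeq,1);
for k = 1:numSeq
    x = rand(1,seqLen);
    XTrain{k} = x;
    YTrain{k} = cumsum(x)/seqLen;          % hypothetical sequence-to-sequence target
end

layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')   % layer 2
    fullyConnectedLayer(1)
    regressionLayer];
options = trainingOptions('adam','MaxEpochs',5,'Verbose',false);
net = trainNetwork(XTrain,YTrain,layers,options);

% Replay each sequence one time step at a time and record the LSTM states.
allHidden = cell(numSeq,1);
allCell = cell(numSeq,1);
for k = 1:numSeq
    net = resetState(net);                 % back to the initial states (zeros here)
    x = XTrain{k};
    H = zeros(numHiddenUnits,seqLen);
    C = zeros(numHiddenUnits,seqLen);
    for t = 1:seqLen
        % Keep the returned network; it carries the updated states forward
        [net,~] = predictAndUpdateState(net,x(:,t));
        H(:,t) = net.Layers(2).HiddenState;
        C(:,t) = net.Layers(2).CellState;
    end
    allHidden{k} = H;
    allCell{k} = C;
end
```

Each allHidden{k} is numHiddenUnits-by-seqLen, so column t holds the hidden state after time step t of sequence k (and likewise for allCell). If your matrices stay all zeros, a first thing to check is that the network returned by predictAndUpdateState is assigned back to the variable you read the states from.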
I personally needed to extract the final states to continue the prediction because I'm working on a forecasting problem.
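For the forecasting case, a rough sketch under the same assumptions: train a small one-step-ahead LSTM on a hypothetical noisy sine wave, run predictAndUpdateState over the observed sequence so the returned network holds the states reached at the last time step, read those final states from the LSTM layer (assumed to be layer 2), and keep predicting from there in closed loop. The data, the 16-unit layer, the epoch count, and the 20-step horizon are illustrative choices, not taken from the thread.

```matlab
% Hypothetical data: a noisy sine wave, 1 feature x 200 time steps.
rng(0);
data = sin(0.1*(1:200)) + 0.05*randn(1,200);
XTrain = data(1:end-1);                    % inputs: steps 1..199
YTrain = data(2:end);                      % targets: the next value, steps 2..200

layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(16,'OutputMode','sequence')  % layer 2
    fullyConnectedLayer(1)
    regressionLayer];
options = trainingOptions('adam','MaxEpochs',50,'Verbose',false);
net = trainNetwork(XTrain,YTrain,layers,options);

% Warm up on the observed inputs; the returned network holds the final states.
net = resetState(net);
net = predictAndUpdateState(net,XTrain);
finalHidden = net.Layers(2).HiddenState;   % 16x1
finalCell = net.Layers(2).CellState;       % 16x1

% Continue from those states: feed the last observed value, then each prediction.
numAhead = 20;
forecast = zeros(1,numAhead);
[net,forecast(1)] = predictAndUpdateState(net,YTrain(end));
for t = 2:numAhead
    [net,forecast(t)] = predictAndUpdateState(net,forecast(t-1));
end
```

Feeding predictions back as inputs (closed loop) is only one choice; if new observations arrive, you can feed those instead and the states still carry over from the same network object.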

Sathyseelan Mayilvahanam, 19 Sep 2022
When I run the code above, it produces matrices full of zeros. Could you please provide a solution or a complete example with data?
  Comments (2)
Yildirim Kocoglu, 19 Sep 2022
At which stage (time step) are you trying to extract the hidden/cell states, and for what purpose? What kind of problem are you working on (classification, forecasting, or something else)? Have you tried printing the hidden/cell states inside the for loop? By the way, the code I posted is not complete; as far as I remember, I borrowed it from the MATLAB documentation (check the link I provided for details). I can't provide a full example because I have since moved to a different programming language for another project. The snippet preallocates the hiddenState and cellState matrices with zeros before the loop, and they are filled in as the loop runs. Separately, if you call resetState(net) inside the for loop, the network's hidden/cell states are reset to their initial states (zeros by default, unless you specified other values yourself). The hidden/cell states are updated as you progress through each time step of a sequence, and you should be able to print them inside the for loop.
Sathyseelan Mayilvahanam, 19 Sep 2022
Thanks for the comments. I will check that.


Release: R2018b
