How to retrieve the cell/hidden state of an LSTM layer during training

Hi everyone,
as the title says, I'm trying to extract the cell & hidden state from an LSTM layer after training. Unfortunately, I haven't found a solution for that yet.
Does anyone know how that works, or whether it is even possible?
Thanks for any advice!

Answers (5)

Da-Ting Lin on 11 Feb 2020
I also have this question. Hopefully it may be included in an upcoming release?

Haoyuan Ma on 16 Mar 2020
I have this question too...
I have tried many times before seeing this page.

Giuseppe Dell'Aversana on 16 Apr 2020
I also have this question. Maybe someone has the answer now?

Yildirim Kocoglu on 10 Jan 2021
It's a little late, but I had the same question and came across this: https://www.mathworks.com/help/ident/ug/use-lstm-for-linear-system-identification.html
I haven't tried it yet, but please read it carefully, as it may help.
See the section "Set Network Initial State".
It says: As the network performs estimation using a step input from 0 to 1, the states of the LSTM network (cell and hidden states of the LSTM layers) drift toward the correct initial condition. To visualize this, extract the cell and hidden state of the network at every time step using the predictAndUpdateState function.
Here is some code from the documentation, which you can try to modify to achieve what you need:
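% Extract the states only for the part of the signal up to time = 2
stepMarker = time <= 2;
yhat = zeros(sum(stepMarker),1);
hiddenState = zeros(sum(stepMarker),200); % 200 LSTM units
cellState = zeros(sum(stepMarker),200);
for ntime = 1:sum(stepMarker)
    % Predict one time step and keep the network with its updated state
    [fourthOrderNet,yhat(ntime)] = predictAndUpdateState(fourthOrderNet,stepSignal(ntime)');
    % Read the current states from the LSTM layer (layer 2 of this network)
    hiddenState(ntime,:) = fourthOrderNet.Layers(2,1).HiddenState;
    cellState(ntime,:) = fourthOrderNet.Layers(2,1).CellState;
end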
If you have multiple batches, you can re-use the same code in a for loop and just predict with your trained network, feeding it one batch at a time (for example, for i=1:batch_size). If you call net = resetState(net) (assuming you saved your trained network as 'net') at the very beginning of each prediction in the for loop, it resets the states to their initial values, which are usually zeros if you did not specify them beforehand. These are the same initial states used during training, so the code above should let you see the hidden and cell states at each time step for each batch.
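A rough sketch of that per-sequence loop follows. It has not been checked against any particular release, and the toy network, the random data, and the names net, sequences, targets, and numHidden are all made up for illustration. Only resetState, predictAndUpdateState, and the HiddenState/CellState layer properties are taken from the documentation snippet above.

```matlab
% Hypothetical toy setup: a small sequence-to-sequence regression network
% trained on random data, only so that the loop below has something to run on.
rng(0);
numHidden = 8;                                     % assumed number of LSTM units
sequences = {rand(1,20); rand(1,30); rand(1,25)};  % three 1-feature sequences
targets   = cellfun(@(x) cumsum(x)/numel(x), sequences, 'UniformOutput', false);

layers = [sequenceInputLayer(1)
          lstmLayer(numHidden)
          fullyConnectedLayer(1)
          regressionLayer];
options = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', false);
net = trainNetwork(sequences, targets, layers, options);

% Record the states at every time step, one sequence at a time.
hiddenStates = cell(numel(sequences), 1);
cellStates   = cell(numel(sequences), 1);
for i = 1:numel(sequences)
    net = resetState(net);                  % start each sequence from the initial state
    x = sequences{i};
    T = size(x, 2);
    hiddenStates{i} = zeros(T, numHidden);  % preallocated storage, not the network state
    cellStates{i}   = zeros(T, numHidden);
    for t = 1:T
        [net, ~] = predictAndUpdateState(net, x(:, t));
        hiddenStates{i}(t, :) = net.Layers(2).HiddenState;  % layer 2 is the LSTM layer here
        cellStates{i}(t, :)   = net.Layers(2).CellState;
    end
end
```

Each entry of hiddenStates and cellStates is then a T-by-numHidden matrix with one row per time step of the corresponding sequence.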
I personally needed to extract the final states to continue the prediction because I'm working on a forecasting problem.
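For that forecasting case, one possible pattern is sketched below. This is an assumption-laden sketch, not code from the linked page: run the observed sequence through the network once so that its state ends at the last observed time step, read off the final states, and then feed each prediction back in as the next input. The sine-wave data, the network size, and all variable names are hypothetical.

```matlab
% Hypothetical toy data: learn to predict the next sample of a sine wave.
rng(0);
signal = sin(0.1*(1:200));
XTrain = signal(1:end-1);          % inputs, 1-by-199
YTrain = signal(2:end);            % targets, shifted one step ahead

layers = [sequenceInputLayer(1)
          lstmLayer(16)
          fullyConnectedLayer(1)
          regressionLayer];
options = trainingOptions('adam', 'MaxEpochs', 20, 'Verbose', false);
net = trainNetwork({XTrain}, {YTrain}, layers, options);

% Warm up: run the observed inputs so the state ends at the last observed step.
net = resetState(net);
[net, YWarm] = predictAndUpdateState(net, XTrain);

% Final states after the observed data (column vectors, one entry per LSTM unit).
finalHidden = net.Layers(2).HiddenState;
finalCell   = net.Layers(2).CellState;

% Closed-loop forecast: each prediction becomes the next input.
numSteps = 50;
YForecast = zeros(1, numSteps);
YForecast(1) = YWarm(end);
for k = 2:numSteps
    [net, YForecast(k)] = predictAndUpdateState(net, YForecast(k-1));
end
```

Because predictAndUpdateState returns the network with its updated state, the forecast loop continues from the final states automatically; finalHidden and finalCell are read out here only for inspection or saving.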

Sathyseelan Mayilvahanam on 19 Sep 2022
The code above produced matrices containing only zeros when I ran it. Could you provide a solution, or code with complete example data?
  2 comments
Yildirim Kocoglu on 19 Sep 2022
At which stage (time step) are you trying to extract the hidden/cell state, and what is your purpose in extracting it? What kind of problem are you working on (classification, forecasting, or something else)? Have you tried printing the hidden/cell states inside the for loop? By the way, the code I provided is not complete: as far as I remember, I borrowed it from the MATLAB documentation (check the link I provided for more details). I don’t have an example I can provide, as I have since moved to a different programming language for a different project. The snippet preallocates the hiddenState and cellState matrices with zeros before the for loop, and if you call resetState(net) inside the loop, that resets the network's hidden/cell states to their initial values (zeros by default if you did not specify them yourself). The hidden/cell states are updated as you step through each time step of a sequence, so you should be able to print them out inside the for loop.

Release

R2018b
