How to retrieve the cell/hidden state of an LSTM layer during training
Valentin Steininger
on 4 Jul 2019
Commented: Sathyseelan Mayilvahanam
on 19 Sep 2022
Hi everyone,
As the title says, I'm trying to extract the cell and hidden states from an LSTM layer after training. Unfortunately, I haven't found a solution yet.
Does anyone know how that works, or whether it is even possible?
Thanks for any advice!
Answers (5)
Da-Ting Lin
on 11 Feb 2020
I also have this question. Hopefully it may be included in an upcoming release?
Haoyuan Ma
on 16 Mar 2020
I have this question too...
I have tried many times before seeing this page.
Giuseppe Dell'Aversana
on 16 Apr 2020
I also have this question. Maybe someone has the answer now?
Yildirim Kocoglu
on 10 Jan 2021
It's a little late, but I had the same question and came across this: https://www.mathworks.com/help/ident/ug/use-lstm-for-linear-system-identification.html
I haven't tried this yet, but please read it carefully, as it may help.
Read the section "Set Network Initial State".
It says: As the network performs estimation using a step input from 0 to 1, the states of the LSTM network (cell and hidden states of the LSTM layers) drift toward the correct initial condition. To visualize this, extract the cell and hidden state of the network at every time step using the predictAndUpdateState function.
Here is some code from the documentation that you can try to modify to achieve what you need:
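stepMarker = time <= 2;                   % logical mask: samples with time <= 2
numSteps = sum(stepMarker);               % number of time steps to record
yhat = zeros(numSteps,1);
hiddenState = zeros(numSteps,200);        % 200 LSTM units
cellState = zeros(numSteps,200);
for ntime = 1:numSteps
    % Predict one time step and keep the updated network so its state advances
    [fourthOrderNet,yhat(ntime)] = predictAndUpdateState(fourthOrderNet,stepSignal(ntime)');
    % Layers(2,1) is taken to be the LSTM layer of this network
    hiddenState(ntime,:) = fourthOrderNet.Layers(2,1).HiddenState;
    cellState(ntime,:) = fourthOrderNet.Layers(2,1).CellState;
end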
If you have multiple batches, you can wrap the same code in an outer for loop and predict on your trained network one batch at a time (for example, for i=1:batch_size). If you call net = resetState(net) at the very beginning of each prediction in that loop (assuming you saved your trained network as 'net'), the states are reset to their initial values, which are usually zeros if you did not specify them beforehand. These are the same initial states used during training, so the code above should let you see the hidden and cell states at each time step for each batch.
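The outer loop described above could look roughly like the following. This is only a sketch under stated assumptions: the toy network, the random data, and the names numHidden, XTrain, hiddenStates and cellStates are made up for illustration, and it assumes a Deep Learning Toolbox release in which trainNetwork and predictAndUpdateState are available.

```matlab
% Toy setup (assumption): a small sequence-to-sequence regression network.
numHidden = 8;
layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(numHidden)
    fullyConnectedLayer(1)
    regressionLayer];
XTrain = {rand(1,20); rand(1,20); rand(1,20)};   % 3 sequences, 1 feature x 20 steps
YTrain = cellfun(@(x) cumsum(x), XTrain, 'UniformOutput', false);
options = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', false);
net = trainNetwork(XTrain, YTrain, layers, options);

% Record the LSTM states at every time step, one sequence at a time.
hiddenStates = cell(numel(XTrain), 1);
cellStates = cell(numel(XTrain), 1);
for i = 1:numel(XTrain)
    net = resetState(net);                 % start each sequence from the initial state
    x = XTrain{i};
    numSteps = size(x, 2);
    h = zeros(numSteps, numHidden);
    c = zeros(numSteps, numHidden);
    for t = 1:numSteps
        % Keep the updated network so the state carries over to the next step
        [net, ~] = predictAndUpdateState(net, x(:, t));
        h(t, :) = net.Layers(2).HiddenState;   % layer 2 is the LSTM layer here
        c(t, :) = net.Layers(2).CellState;
    end
    hiddenStates{i} = h;
    cellStates{i} = c;
end
```

Note that resetState is called once per sequence, outside the time-step loop; calling it inside the inner loop would restart the state at every step.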
I personally needed to extract the final states to continue the prediction because I'm working on a forecasting problem.
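For that forecasting use case, one common pattern is to run the trained network over the observed data so that it holds the final states, and then feed each prediction back in as the next input. The sketch below is not the poster's actual code: the sine-wave data, the layer sizes, and the names finalHidden, finalCell and forecast are assumptions for illustration.

```matlab
% Toy data (assumption): predict the next sample of a sine wave.
data = sin(0.1 * (1:200));
XTrain = data(1:end-1);                  % 1 feature x 199 time steps
YTrain = data(2:end);                    % targets shifted by one step
layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(16)
    fullyConnectedLayer(1)
    regressionLayer];
options = trainingOptions('adam', 'MaxEpochs', 20, 'Verbose', false);
net = trainNetwork(XTrain, YTrain, layers, options);

% Run over the observed sequence; afterwards net holds the final states.
net = resetState(net);
[net, yPrimed] = predictAndUpdateState(net, XTrain);
finalHidden = net.Layers(2).HiddenState;   % state after the last observed step
finalCell = net.Layers(2).CellState;

% Continue from those states, using each prediction as the next input.
numForecast = 10;
forecast = zeros(1, numForecast);
lastPred = yPrimed(end);
for k = 1:numForecast
    [net, lastPred] = predictAndUpdateState(net, lastPred);
    forecast(k) = lastPred;
end
```

Here finalHidden and finalCell are only read out for inspection; the forecast loop continues from them automatically because the updated net is reused.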
Sathyseelan Mayilvahanam
on 19 Sep 2022
The above-mentioned code produced matrices containing only zeros when I ran it. Could you kindly provide a solution, or code with complete example data?
2 comments
Yildirim Kocoglu
on 19 Sep 2022
At which stage (time step) are you trying to extract the hidden/cell state, and what is your purpose in extracting it? What kind of problem are you working on (classification, forecasting, or something else)? Have you tried printing the hidden/cell states within the for loop in the code?
By the way, the code I provided is not complete: as far as I remember, I borrowed it from the MATLAB documentation (check the link I provided for more details). I don't have an example I can provide, as I have moved to a different programming language altogether for a different project.
The code snippet fills the hiddenState and cellState matrices with zeros before entering the for loop. If you call resetState(net) inside the for loop, the hidden/cell states are reset to their initial values (zeros by default, unless you specified them yourself beforehand). The hidden/cell states are updated as you progress through each time step of a sequence, and you should be able to print them out within the for loop.
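The printing check suggested above could be sketched as follows. The network, data, and the name stateNorms are hypothetical. The comments flag two possible causes of all-zero results, neither of which is confirmed in this thread: not capturing the updated network returned by predictAndUpdateState, and calling resetState inside the time-step loop.

```matlab
% Toy network and data (assumption), just to have something to step through.
layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(4)
    fullyConnectedLayer(1)
    regressionLayer];
x = linspace(0.1, 1, 10);                % one sequence: 1 feature x 10 steps
y = 2 * x;
options = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', false);
net = trainNetwork(x, y, layers, options);

net = resetState(net);                   % reset once, before the time-step loop
stateNorms = zeros(1, numel(x));
for t = 1:numel(x)
    % Capture the first output; if it is discarded, net keeps its old state.
    [net, ~] = predictAndUpdateState(net, x(t));
    stateNorms(t) = norm(net.Layers(2).HiddenState);
    fprintf('step %d: |h| = %.4f, |c| = %.4f\n', t, ...
        stateNorms(t), norm(net.Layers(2).CellState));
end
```

If the printed norms stay at zero for every step, the state is probably not being carried forward from one call to the next.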