
resetState

Reset state parameters of neural network

Description

netUpdated = resetState(net) resets the state parameters of a neural network. Use this function to reset the state of a recurrent neural network such as an LSTM network.
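The following minimal sketch shows what resetState does to the State property of a dlnetwork object. Instead of loading a pretrained network, it builds a small, randomly initialized LSTM network. The layer sizes and the random input are illustrative assumptions. It lets one prediction update the state and then resets it. It assumes that, as in the example in this topic, predict accepts a numeric sequence together with the InputDataFormats option. It also assumes that the reset state equals the state the network had immediately after initialization.

```matlab
% Sketch: a small, randomly initialized LSTM classifier. The layer sizes are
% arbitrary illustration choices, not those of a particular pretrained network.
layers = [
    sequenceInputLayer(12)
    lstmLayer(100,OutputMode="last")
    fullyConnectedLayer(9)
    softmaxLayer];
net = dlnetwork(layers);
initialState = net.State;     % hidden and cell states of the LSTM layer after initialization

% Classify one random sequence (12 channels, 20 time steps) and keep the updated state.
X = rand(12,20);
[~,state] = predict(net,X,InputDataFormats="CT");
net.State = state;            % the state now depends on X

% Reset the state parameters (assumed to restore the initialized values).
netUpdated = resetState(net);
netUpdated.State              % table with one row per state parameter
```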


Examples


Reset the network state between sequence predictions.

Load dlnetJapaneseVowels, a pretrained long short-term memory (LSTM) network trained on the Japanese Vowels data set as described in [1] and [2]. This network was trained on the sequences sorted by sequence length with a mini-batch size of 27.

load dlnetJapaneseVowels

View the network architecture.

net.Layers
ans = 
  4x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input    Sequence input with 12 dimensions
     2   'lstm'            LSTM              LSTM with 100 hidden units
     3   'fc'              Fully Connected   9 fully connected layer
     4   'softmax'         Softmax           softmax

Load the test data.

load JapaneseVowelsTestData

Classify a sequence and update the network state.

X = XTest{94};
[scores,state] = predict(net,X,InputDataFormats="CT");
net.State = state;
label = scores2label(scores,classNames)
label = categorical
     3 

Classify another sequence using the updated network.

X = XTest{1};
scores = predict(net,X,InputDataFormats="CT");
label = scores2label(scores,classNames)
label = categorical
     7 

Compare the final prediction with the true label.

trueLabel = TTest(1)
trueLabel = categorical
     1 

The state carried over from classifying XTest{94} changed this prediction, which does not match the true label. Reset the network state and classify the sequence again.

net = resetState(net);
scores = predict(net,X,InputDataFormats="CT");
label = scores2label(scores,classNames)
label = categorical
     1 
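The state matters most when a sequence is fed to the network in pieces, for example one time step at a time as data arrives. The following minimal sketch carries the state from one time step to the next within a sequence. It calls resetState before the next, independent sequence starts. The network and the two random sequences are illustrative stand-ins, not the pretrained network and test data used above. The variable names sequences and streamScores are hypothetical.

```matlab
% Sketch: stream independent sequences one time step at a time.
% The network and data are illustrative stand-ins (random initialization and input).
layers = [
    sequenceInputLayer(12)
    lstmLayer(100,OutputMode="last")
    fullyConnectedLayer(9)
    softmaxLayer];
net = dlnetwork(layers);
sequences = {rand(12,15), rand(12,25)};     % two independent sequences

streamScores = cell(size(sequences));
for i = 1:numel(sequences)
    net = resetState(net);                  % start each sequence from the initial state
    X = sequences{i};
    for t = 1:size(X,2)
        % Predict on a single time step and carry the state to the next step.
        [scores,state] = predict(net,X(:,t),InputDataFormats="CT");
        net.State = state;
    end
    streamScores{i} = scores;               % scores after the final time step
end
```

Each sequence starts from the reset state, so the final scores for the second sequence do not depend on the first. Stepping through a sequence this way should give the same final scores as predicting on the whole sequence at once, up to floating-point rounding.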

Input Arguments


net — Neural network, specified as a dlnetwork object.

The resetState function only has an effect if net has state parameters (for example, a network with at least one recurrent layer such as an LSTM layer). If the input network does not have state parameters, then the function has no effect and returns the input network.
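For a network without state parameters, the call is a no-op. The following sketch uses an arbitrary feedforward network chosen only for illustration. It confirms that the State property is empty and then calls resetState. The check on Learnables assumes that isequal compares the dlarray values stored in that table.

```matlab
% Sketch: a feedforward network with no recurrent layers, so it has no state
% parameters. Layer sizes are arbitrary illustration choices.
layers = [
    featureInputLayer(12)
    fullyConnectedLayer(9)
    softmaxLayer];
net = dlnetwork(layers);
isempty(net.State)                 % true: no layer has state parameters

netUpdated = resetState(net);      % no effect; the input network is returned
```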

Output Arguments


netUpdated — Updated network, returned as a dlnetwork object.

If net does not have state parameters, then netUpdated is the same as net.

References

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels


Version History

Introduced in R2017b
