Code Generation for LSTM Network That Classifies Text Data

This example shows how to generate generic C code for a pretrained long short-term memory (LSTM) network that classifies text data. The example generates a MEX function that makes predictions for each step of an input time series, and it demonstrates two approaches. The first approach passes the whole sequence to a standard LSTM network. The second approach uses the stateful behavior of the same LSTM network to process the sequence one time step at a time. The input data consists of textual descriptions of factory events, each of which belongs to one of four categories: Electronic Failure, Leak, Mechanical Failure, and Software Failure. For more information about the pretrained LSTM network, see Classify Text Data Using Deep Learning (Text Analytics Toolbox).

This example is supported on Mac®, Linux®, and Windows® platforms. It is not supported in MATLAB Online.

Prepare Input

Load the wordEncoding MAT-file. This MAT-file stores the words encoded as numerical indices. The encoding was created during the training of the network, and the doc2sequence call later in this section uses it through the variable enc. For more information, see Classify Text Data Using Deep Learning (Text Analytics Toolbox).

load("wordEncoding.mat");

Create a string array containing the new reports whose event types you want to classify.

reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."
    "At times mechanical arrangement software freezes."
    "Mixer output is stuck."];

Tokenize the input strings by using the preprocessText function.

documentsNew = preprocessText(reportsNew);

Use the doc2sequence (Text Analytics Toolbox) function to convert documents to sequences.

XNew = doc2sequence(enc,documentsNew);
labels = categorical({'Electronic Failure', 'Leak', 'Mechanical Failure', 'Software Failure'});

The lstm_predict Entry-Point Function

A sequence-to-sequence LSTM network enables you to make different predictions for each individual time step of a data sequence. The lstm_predict.m entry-point function takes an input sequence and passes it to a trained LSTM network for prediction. Specifically, the function uses the LSTM network that is trained in the example Classify Text Data Using Deep Learning (Text Analytics Toolbox). The function loads the network object from the textClassifierNetwork.mat file into a persistent variable and then performs prediction. On subsequent calls, the function reuses the persistent object.

function out = lstm_predict(in)

%   Copyright 2020-2024 The MathWorks, Inc.

    % Label the input dimensions as channel (C) by time (T).
    dlIn = dlarray(in,'CT');
    persistent dlnet;

    % Load the network only on the first call.
    if isempty(dlnet)
        dlnet = coder.loadDeepLearningNetwork('textClassifierNetwork.mat');
    end

    dlOut = predict(dlnet, dlIn);
    out = extractdata(dlOut);
end

To display an interactive visualization of the network architecture and information about the network layers, use the analyzeNetwork (Deep Learning Toolbox) function.

Generate MEX

To generate code, create a code configuration object for a MEX target and set the target language to C. Use the coder.DeepLearningConfig function to create a deep learning configuration object that does not depend on third-party libraries. Assign it to the DeepLearningConfig property of the code configuration object.

cfg = coder.config('mex');
cfg.TargetLang = 'C';
cfg.IntegrityChecks = false;
cfg.DeepLearningConfig = coder.DeepLearningConfig(TargetLibrary = 'none');

Use the coder.typeof function to specify the type and size of the input argument to the entry-point function. In this example, the input is of single data type with a feature dimension value of 1 and a variable sequence length.

matrixInput = coder.typeof(single(0),[1 Inf],[false true]);
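
As a quick check of what this type specification describes, the following sketch recreates the same coder.typeof call under a hypothetical variable name, seqType, and inspects its properties. It assumes that MATLAB Coder is installed. The variable name is for illustration only and is not part of the original example.

```matlab
% seqType is a hypothetical name; the call mirrors the matrixInput
% definition: single elements, 1 row, unbounded variable-length columns.
seqType = coder.typeof(single(0),[1 Inf],[false true]);

disp(seqType.ClassName)     % class of the elements
disp(seqType.SizeVector)    % upper bound of each dimension
disp(seqType.VariableDims)  % which dimensions can vary in size
```

Only the second dimension is marked as variable, so the generated function accepts sequences of any length but always exactly one feature per time step.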

Generate a MEX function by running the codegen command.

codegen -config cfg lstm_predict -args {matrixInput} -report
Code generation successful: View report

Run Generated MEX

Call lstm_predict_mex on the first observation.

YPred1 = lstm_predict_mex(single(XNew{1}));

YPred1 contains the probabilities for the four classes. Find the predicted class by calculating the index of the maximum probability.

[~, maxIndex] = max(YPred1);

Associate the index of the maximum probability with the corresponding label and display the classification. The network predicts that the first event is a Leak.

predictedLabels1 = labels(maxIndex);
disp(predictedLabels1)
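
The index-to-label mapping can be tried without the generated MEX function. The following is a minimal sketch that uses a made-up score vector, hypotheticalScores, in place of YPred1. The values are illustrative only and are not output from the network.

```matlab
labels = categorical({'Electronic Failure', 'Leak', 'Mechanical Failure', 'Software Failure'});

% Hypothetical scores standing in for YPred1 (one value per class).
hypotheticalScores = single([0.05; 0.85; 0.07; 0.03]);

% max returns the largest score and its position in the vector.
[maxScore, maxIndex] = max(hypotheticalScores);

% Index into labels using the position of the largest score.
predictedLabel = labels(maxIndex);
fprintf('%s (score %.2f)\n', string(predictedLabel), maxScore);
```

The mapping relies on the labels array listing the classes in the same order as the network output, which is what the example assumes.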

Generate MEX with Stateful LSTM

Instead of passing the entire time series to predict in one step, you can stream the input one time step at a time and update the state of the dlnetwork after each step. The predict (Deep Learning Toolbox) function can return the updated network state along with the output prediction. The lstm_predict_and_update.m entry-point function takes a single time step as input and updates the state of the network so that subsequent inputs are treated as subsequent time steps of the same sample. After you pass in all time steps one at a time, the resulting output is the same as if all time steps were passed in as a single input.

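function out = lstm_predict_and_update(in)

%   Copyright 2020-2024 The MathWorks, Inc.

    % Label the input dimensions as channel (C) by time (T).
    dlIn = dlarray(in,'CT');
    persistent dlnet;

    % Load the network only on the first call.
    if isempty(dlnet)
        dlnet = coder.loadDeepLearningNetwork('textClassifierNetwork.mat');
    end

    % Predict and carry the updated state over to the next call.
    [dlOut, updatedState] = predict(dlnet, dlIn);
    dlnet.State = updatedState;
    out = extractdata(dlOut);
end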

Generate code for lstm_predict_and_update. Because this function accepts a single timestep at each call, specify matrixInput to have a fixed sequence dimension of 1 instead of a variable sequence length.

matrixInput = coder.typeof(single(0),[1 1]);
codegen -config cfg lstm_predict_and_update -args {matrixInput} -report
Code generation successful: View report

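Run the generated MEX function on the first observation, one time step per call. After the loop, clear the MEX function from memory to reset the persistent network state before you process another sequence.

sequenceLength = size(XNew{1},2);
for i=1:sequenceLength
    inTimeStep = XNew{1}(:,i);
    YPred3 = lstm_predict_and_update_mex(single(inTimeStep));
end
clear lstm_predict_and_update_mex;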

Find the index that has the highest probability and map it to the labels.

[~, maxIndex] = max(YPred3);
predictedLabels3 = labels(maxIndex);
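
Because the streamed prediction is expected to match the whole-sequence prediction, it can be useful to compare the two score vectors numerically. The following sketch defines a hypothetical helper, scoresAgree, and uses an assumed tolerance of 1e-5. Neither is part of the original example. The helper is demonstrated on made-up vectors so that it runs without the generated MEX functions.

```matlab
% Hypothetical helper: true when two score vectors differ by at most tol.
scoresAgree = @(a, b, tol) max(abs(double(a(:)) - double(b(:)))) <= tol;

% Made-up vectors standing in for YPred1 and YPred3.
a = single([0.05; 0.85; 0.07; 0.03]);
b = a + single(1e-7);   % small perturbation, well below the tolerance

disp(scoresAgree(a, b, 1e-5))
```

With both MEX results in the workspace, the same check would be scoresAgree(YPred1, YPred3, 1e-5). The tolerance allows for small floating-point differences between the two code paths and may need adjusting.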
