This example shows how to classify out-of-memory text data with a deep learning network using a custom mini-batch datastore.
A mini-batch datastore is an implementation of a datastore with support for reading data in batches. You can use a mini-batch datastore as a source of training, validation, test, and prediction data sets for deep learning applications. Use mini-batch datastores to read out-of-memory data or to perform specific preprocessing operations when reading batches of data.
When training the network, the software creates mini-batches of sequences of the same length by padding, truncating, or splitting the input data. The trainingOptions function provides options to pad and truncate input sequences; however, these options are not well suited to sequences of word vectors. Furthermore, trainingOptions does not support padding data in a custom datastore, so you must pad and truncate the sequences manually. Left-padding and truncating the sequences of word vectors might improve training.
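To see why left-padding can matter, note that an LSTM layer with the output mode 'last' returns only the output at the final time step. The following base-MATLAB sketch compares right-padding and left-padding of a single two-word document. The 3-dimensional "embedding", the document values, and the target length of 4 are hypothetical, and this is an illustration rather than code taken from textDatastore.m.

```matlab
% Hypothetical toy document: numFeatures-by-numWords, one column per word.
numFeatures = 3;
doc = [1 2; 3 4; 5 6];
targetLength = 4;   % hypothetical padded length

% Zero columns used as padding time steps.
padding = zeros(numFeatures, targetLength - size(doc,2));

rightPadded = [doc padding];   % padding after the words
leftPadded  = [padding doc];   % padding before the words

% An LSTM layer with 'OutputMode','last' outputs only the final time step.
lastRight = rightPadded(:,end);   % a padding step (all zeros)
lastLeft  = leftPadded(:,end);    % the vector of the final word
```

After right-padding, the final time step is padding, so the output depends on how well the LSTM state carries information through the padding steps. After left-padding, the final time step is the last real word. This is one plausible reason left-padding can help, not a guarantee.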
The Classify Text Data Using Deep Learning example manually truncates and pads all the documents to the same length. This process adds lots of padding to very short documents and discards lots of data from very long documents.
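The following sketch quantifies that trade-off for a fixed sequence length. The document lengths and the fixed length of 20 words are hypothetical values chosen for illustration.

```matlab
% Hypothetical number of words in each document.
docLengths = [3 5 8 40 120];
fixedLength = 20;   % hypothetical common length for all documents

% Padding steps added to documents shorter than fixedLength.
paddingAdded = sum(max(fixedLength - docLengths, 0));

% Words discarded from documents longer than fixedLength.
wordsDiscarded = sum(max(docLengths - fixedLength, 0));
```

Here the three short documents receive 44 padding steps in total, while the two long documents lose 120 words.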
Alternatively, to avoid adding too much padding or discarding too much data, create a custom mini-batch datastore that inputs mini-batches into the network. The custom mini-batch datastore textDatastore.m converts mini-batches of documents to sequences of word vectors and left-pads each mini-batch to the length of the longest document in that mini-batch. For sorted data, this datastore can reduce the amount of padding added to the data because documents are not padded to a fixed length. The datastore also does not discard any data from the documents.
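A minimal sketch of per-mini-batch left-padding follows, assuming each document is stored as a numFeatures-by-numWords matrix in a cell array. The batch contents, document lengths, and mini-batch size of 2 are hypothetical, and textDatastore.m may implement the padding differently. The second part compares the total padding for unsorted and sorted document lengths.

```matlab
% Hypothetical mini-batch of two documents (numFeatures = 3).
batch = {ones(3,2), 2*ones(3,4)};

% Left-pad every sequence with zero columns up to the longest sequence
% in this mini-batch, not up to a global fixed length.
maxLen = max(cellfun(@(x) size(x,2), batch));
padded = cellfun(@(x) [zeros(size(x,1), maxLen - size(x,2)) x], ...
    batch, 'UniformOutput', false);

% Total padding when each mini-batch is padded to its own longest document.
docLengths = [3 120 5 40 8 10];   % hypothetical word counts
miniBatchSize = 2;
sortedLengths = sort(docLengths);

unsortedPadding = 0;
sortedPadding = 0;
for k = 1:miniBatchSize:numel(docLengths)
    idx = k:min(k + miniBatchSize - 1, numel(docLengths));
    unsortedPadding = unsortedPadding + sum(max(docLengths(idx)) - docLengths(idx));
    sortedPadding = sortedPadding + sum(max(sortedLengths(idx)) - sortedLengths(idx));
end
```

With these lengths, the unsorted mini-batches need 154 padding steps and the sorted ones need 84. Padding all six documents to the longest length (120) would add 534 steps.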
This example uses the custom mini-batch datastore textDatastore.m. You can adapt this datastore to your data by customizing its functions. For an example showing how to create your own custom mini-batch datastore, see Develop Custom Mini-Batch Datastore (Deep Learning Toolbox).
textDatastore requires a word embedding to convert documents to sequences of vectors. Load a pretrained word embedding using fastTextWordEmbedding. This function requires the Text Analytics Toolbox™ Model for fastText English 16 Billion Token Word Embedding support package. If this support package is not installed, then the function provides a download link.
emb = fastTextWordEmbedding;
Create a datastore that contains the data for training. The custom mini-batch datastore textDatastore reads predictors and labels from a CSV file. For the predictors, the datastore converts the documents into sequences of word vectors. For the responses, the datastore returns a categorical label for each document.
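A mini-batch might not contain every class. One way to keep the responses consistent across mini-batches is to convert the label strings using the full list of class names as the value set. In this sketch, classNames matches the ClassNames property that dsTrain displays for this data, while batchLabels is a hypothetical mini-batch; whether textDatastore.m converts labels this way is an assumption.

```matlab
% Class names of the factory reports data.
classNames = ["Electronic Failure" "Leak" "Mechanical Failure" "Software Failure"];

% Hypothetical labels read for one mini-batch.
batchLabels = ["Leak"; "Leak"; "Software Failure"];

% Passing classNames as the value set keeps all four categories,
% even though two classes do not appear in this mini-batch.
responses = categorical(batchLabels, classNames);
```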
To create the datastore, first save the file textDatastore.m to a folder on the MATLAB path. For more information about creating custom mini-batch datastores, see Develop Custom Mini-Batch Datastore (Deep Learning Toolbox).
For the training data, specify the CSV file "factoryReports.csv" and specify that the text and labels are in the columns "Description" and "Category", respectively.
filenameTrain = "factoryReports.csv";
textName = "Description";
labelName = "Category";
dsTrain = textDatastore(filenameTrain,textName,labelName,emb)
dsTrain = 
  textDatastore with properties:

            ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]
             Datastore: [1×1 matlab.io.datastore.TransformedDatastore]
    EmbeddingDimension: 300
             LabelName: "Category"
         MiniBatchSize: 128
            NumClasses: 4
       NumObservations: 480
Define the LSTM network architecture. To input sequence data into the network, include a sequence input layer and set the input size to the embedding dimension. Next, include an LSTM layer with 180 hidden units. To use the LSTM layer for a sequence-to-label classification problem, set the output mode to 'last'. Finally, add a fully connected layer with an output size equal to the number of classes, a softmax layer, and a classification layer.
numFeatures = dsTrain.EmbeddingDimension;
numHiddenUnits = 180;
numClasses = dsTrain.NumClasses;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
Specify the training options. Set the solver to 'adam', the mini-batch size to 128, and the gradient threshold to 2. The datastore textDatastore.m does not support shuffling, so set 'Shuffle' to 'never'. For an example showing how to implement a datastore with support for shuffling, see Develop Custom Mini-Batch Datastore (Deep Learning Toolbox). To monitor the training progress, set the 'Plots' option to 'training-progress'. To suppress verbose output, set 'Verbose' to false.
trainNetwork uses a GPU if one is available. To specify the execution environment manually, use the 'ExecutionEnvironment' name-value pair argument of trainingOptions. Training on a CPU can take significantly longer than training on a GPU. Training using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see .
miniBatchSize = 128;
numObservations = dsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations / miniBatchSize);

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',2, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);
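With the values that dsTrain reports for this data (NumObservations: 480, MiniBatchSize: 128), the arithmetic behind numIterationsPerEpoch works out as follows. The sketch hard-codes those two numbers so that it runs without the datastore.

```matlab
numObservations = 480;   % NumObservations reported by dsTrain
miniBatchSize = 128;

% Number of complete mini-batches in one pass over the data.
numIterationsPerEpoch = floor(numObservations / miniBatchSize);

% Observations covered by complete mini-batches, and the leftover count.
observationsInFullBatches = numIterationsPerEpoch * miniBatchSize;
remainder = numObservations - observationsInFullBatches;
```

Three complete mini-batches cover 384 observations, leaving 96. Whether those 96 observations form a smaller final mini-batch depends on how the datastore's read method is written.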
Train the LSTM network using the trainNetwork function.
net = trainNetwork(dsTrain,layers,options);
Classify the event type of three new reports. Create a string array containing the new reports.
reportsNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];
Preprocess the text data using the same preprocessing steps as the datastore textDatastore.
documents = tokenizedDocument(reportsNew);
documents = lower(documents);
documents = erasePunctuation(documents);
predictors = doc2sequence(emb,documents);
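To show roughly what the lowercasing and punctuation-removal steps do to the raw strings, the following base-MATLAB sketch approximates them with regexprep and split. It is a rough stand-in only: it does not reproduce the tokenization rules of tokenizedDocument or the exact behavior of erasePunctuation, and it does not perform the embedding lookup that doc2sequence does.

```matlab
reports = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."];

% Lowercase, then delete every character that is not a letter, digit,
% underscore, or whitespace (approximation of erasePunctuation).
cleaned = regexprep(lower(reports), '[^\w\s]', '');

% Split the first report into words on whitespace (approximation of tokenization).
tokens = split(strip(cleaned(1)));
```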
Classify the new sequences using the trained LSTM network.
labelsNew = classify(net,predictors)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 
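The softmax and classification layers produce one score per class, and the predicted label is the class with the highest score. You can obtain the scores from the trained network with [labelsNew,scores] = classify(net,predictors). The following sketch shows the score-to-label step using hypothetical scores, one row per report and one column per class, in the same order as the ClassNames property of dsTrain.

```matlab
classNames = ["Electronic Failure" "Leak" "Mechanical Failure" "Software Failure"];

% Hypothetical softmax scores (each row sums to 1).
scores = [0.05 0.90 0.03 0.02
          0.80 0.05 0.10 0.05
          0.10 0.05 0.80 0.05];

% Pick the highest-scoring class for each report.
[~, idx] = max(scores, [], 2);
predictedNames = classNames(idx);
labels = categorical(predictedNames(:), classNames);
```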
See also: lstmLayer (Deep Learning Toolbox) | sequenceInputLayer (Deep Learning Toolbox) | trainingOptions (Deep Learning Toolbox) | trainNetwork (Deep Learning Toolbox)