Create Simple Text Model for Classification

This example shows how to train a simple text classifier on word frequency counts using a bag-of-words model.

A bag-of-words model records how often each word appears in each document, and you can use those word frequency counts as predictors for a classifier. This example trains such a model to predict the category of factory reports from their text descriptions.

Load and Extract Text Data

Load the example data. The file factoryReports.csv contains factory reports, including a text description and categorical labels for each report.

filename = "factoryReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

Convert the labels in the Category column of the table to categorical and view the distribution of the classes in the data using a histogram.

data.Category = categorical(data.Category);
figure
histogram(data.Category)
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

Partition the data into a training set and a held-out test set. Reserve 10% of the reports for testing.

cvp = cvpartition(data.Category,'Holdout',0.1);
dataTrain = data(cvp.training,:);
dataTest = data(cvp.test,:);
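
Because cvpartition draws the holdout set at random, the split (and therefore the accuracy reported later) can differ from run to run. The following is a minimal sketch of one way to make the split repeatable by seeding the random number generator first. The toy labels and the variable names toyLabels, cvpA, and cvpB are illustrative assumptions, not part of the example data.

```matlab
% Toy labels standing in for data.Category (illustrative only).
toyLabels = categorical(repmat(["Leak"; "Mechanical Failure"],50,1));

% Seed the generator, then partition. Repeating both steps
% reproduces the same holdout assignment.
rng('default')
cvpA = cvpartition(toyLabels,'Holdout',0.1);
rng('default')
cvpB = cvpartition(toyLabels,'Holdout',0.1);

sameSplit = isequal(test(cvpA),test(cvpB))

% With a grouping variable, cvpartition stratifies by class, so the
% test set keeps roughly the same class proportions as the full data.
testCounts = countcats(toyLabels(test(cvpA)))
```

In the example itself, calling rng before the cvpartition line would have the same effect on dataTrain and dataTest.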

Extract the text data and labels from the tables.

textDataTrain = dataTrain.Description;
textDataTest = dataTest.Description;
YTrain = dataTrain.Category;
YTest = dataTest.Category;

Prepare Text Data for Analysis

Create a function that tokenizes and preprocesses the text data so it can be used for analysis. The function preprocessText performs the following steps in order:

  1. Tokenize the text using tokenizedDocument.

  2. Remove a list of stop words (such as "and", "of", and "the") using removeStopWords.

  3. Lemmatize the words using normalizeWords.

  4. Erase punctuation using erasePunctuation.

  5. Remove words with 2 or fewer characters using removeShortWords.

  6. Remove words with 15 or more characters using removeLongWords.

Use the example preprocessing function preprocessText to prepare the text data.

documents = preprocessText(textDataTrain);
documents(1:5)
ans = 
  5×1 tokenizedDocument:

    6 tokens: items occasionally get stuck scanner spool
    7 tokens: loud rattle bang sound come assembler piston
    4 tokens: cut power start plant
    3 tokens: fry capacitor assembler
    3 tokens: mixer trip fuse

Create a bag-of-words model from the tokenized documents.

bag = bagOfWords(documents)
bag = 
  bagOfWords with properties:

          Counts: [432×336 double]
      Vocabulary: [1×336 string]
        NumWords: 336
    NumDocuments: 432
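
To make the Counts property concrete, here is a small sketch on a toy corpus. The two documents and the toy variable names are made up for illustration. Each row of Counts corresponds to a document, and each column corresponds to a word in Vocabulary.

```matlab
% Toy corpus: two short, already-clean documents (illustrative only).
toyDocuments = tokenizedDocument([
    "fuse blown mixer fuse"
    "coolant leak mixer"]);
toyBag = bagOfWords(toyDocuments);

% Counts is stored as a sparse matrix; convert to full for display.
toyVocabulary = toyBag.Vocabulary
toyCounts = full(toyBag.Counts)

% Number of times "fuse" occurs in the first document.
fuseCount = toyCounts(1,toyVocabulary == "fuse")
```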

Remove words that appear two times or fewer in total from the bag-of-words model. Then remove any documents that no longer contain words, and remove the corresponding entries from the labels.

bag = removeInfrequentWords(bag,2);
[bag,idx] = removeEmptyDocuments(bag);
YTrain(idx) = [];
bag
bag = 
  bagOfWords with properties:

          Counts: [432×155 double]
      Vocabulary: [1×155 string]
        NumWords: 155
    NumDocuments: 432
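
In the output above, NumDocuments is unchanged, so no training document became empty. The sketch below uses a toy corpus (made up for illustration, with hypothetical variable names) in which one document does become empty, to show how the index returned by removeEmptyDocuments keeps the labels aligned with the remaining documents.

```matlab
% Toy corpus and labels (illustrative only).
toyDocuments = tokenizedDocument([
    "fuse blown mixer fuse"
    "coolant leak"
    "mixer fuse trip"]);
toyLabels = categorical(["Electronic Failure"; "Leak"; "Electronic Failure"]);

toyBag = bagOfWords(toyDocuments);

% Drop words that appear at most once in total. Only "fuse" (3 times)
% and "mixer" (2 times) survive.
toyBag = removeInfrequentWords(toyBag,1);

% The second document now has no words left. Remove it and its label.
[toyBag,idx] = removeEmptyDocuments(toyBag);
toyLabels(idx) = [];
```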

Train Supervised Classifier

Train a supervised classification model using the word frequency counts from the bag-of-words model and the labels.

Train a multiclass linear classification model using fitcecoc. Use the Counts property of the bag-of-words model as the predictors and the event type labels as the response. Specify linear learners, which support sparse data input.

XTrain = bag.Counts;
mdl = fitcecoc(XTrain,YTrain,'Learners','linear')
mdl = 
  CompactClassificationECOC
      ResponseName: 'Y'
        ClassNames: [Electronic Failure    Leak    Mechanical Failure    Software Failure]
    ScoreTransform: 'none'
    BinaryLearners: {6×1 cell}
      CodingMatrix: [4×6 double]



For a better fit, you can try specifying different parameters of the linear learners. For more information on linear classification learner templates, see templateLinear.
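
As one way to explore this, the sketch below passes a templateLinear object to fitcecoc in place of the 'linear' shorthand. The synthetic data (XDemo, YDemo) stands in for XTrain and YTrain so that the snippet runs on its own. The choice of logistic learners with ridge regularization and a Lambda of 1e-3 is an arbitrary assumption for illustration, not a recommended setting.

```matlab
rng('default')                               % reproducible synthetic data
XDemo = sparse(randi([0 3],60,10));          % 60 "documents", 10 "words"
YDemo = categorical(randi(3,60,1),1:3,["A" "B" "C"]);

% Linear learner template: logistic regression with a ridge penalty.
t = templateLinear('Learner','logistic', ...
    'Regularization','ridge','Lambda',1e-3);

mdlDemo = fitcecoc(XDemo,YDemo,'Learners',t);

% One-vs-one coding trains one binary learner per pair of classes.
numLearners = numel(mdlDemo.BinaryLearners)
```

With the example data, you would pass XTrain and YTrain instead, and compare candidate settings on cross-validated or separately held-out data rather than on the test set.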

Test Classifier

Evaluate the trained model on the held-out test data. The classification accuracy is the proportion of test labels that the model predicts correctly.

Preprocess the test data using the same preprocessing steps as the training data. Encode the resulting test documents as a matrix of word frequency counts according to the bag-of-words model.

documentsTest = preprocessText(textDataTest);
XTest = encode(bag,documentsTest);

Predict the labels of the test data using the trained model and calculate the classification accuracy.

YPred = predict(mdl,XTest);
acc = sum(YPred == YTest)/numel(YTest)
acc = 0.8542
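
Overall accuracy can hide which classes are confused with one another. The following is a minimal sketch using confusionmat from Statistics and Machine Learning Toolbox on made-up label vectors. YTrueDemo and YPredDemo are illustrative stand-ins for YTest and YPred.

```matlab
YTrueDemo = categorical(["Leak"; "Leak"; "Mechanical Failure"; ...
    "Electronic Failure"; "Mechanical Failure"]);
YPredDemo = categorical(["Leak"; "Mechanical Failure"; "Mechanical Failure"; ...
    "Electronic Failure"; "Mechanical Failure"]);

% Rows are true classes, columns are predicted classes.
[C,classOrder] = confusionmat(YTrueDemo,YPredDemo);

% Accuracy is the share of observations on the diagonal.
accDemo = sum(diag(C))/sum(C(:))

% Per-class recall: correct predictions divided by true class counts.
recallDemo = diag(C)./sum(C,2)
```

For the example itself, confusionchart(YTest,YPred) plots the same information as a chart.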

Predict Using New Data

Classify the event type of new factory reports. Create a string array containing the new factory reports.

str = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];
documentsNew = preprocessText(str);
XNew = encode(bag,documentsNew);
labelsNew = predict(mdl,XNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

Example Preprocessing Function

The function preprocessText performs the following steps in order:

  1. Tokenize the text using tokenizedDocument.

  2. Remove a list of stop words (such as "and", "of", and "the") using removeStopWords.

  3. Lemmatize the words using normalizeWords.

  4. Erase punctuation using erasePunctuation.

  5. Remove words with 2 or fewer characters using removeShortWords.

  6. Remove words with 15 or more characters using removeLongWords.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Remove a list of stop words then lemmatize the words. To improve
% lemmatization, first use addPartOfSpeechDetails.
documents = addPartOfSpeechDetails(documents);
documents = removeStopWords(documents);
documents = normalizeWords(documents,'Style','lemma');

% Erase punctuation.
documents = erasePunctuation(documents);

% Remove words with 2 or fewer characters, and words with 15 or more
% characters.
documents = removeShortWords(documents,2);
documents = removeLongWords(documents,15);

end
