# Train Generalized Additive Model for Binary Classification

This example shows how to train a generalized additive model (GAM) with optimal parameters and how to assess the predictive performance of the trained model. The example first finds the optimal parameter values for a univariate GAM (parameters for linear terms) and then finds the values for a bivariate GAM (parameters for interaction terms). Also, the example explains how to interpret the trained model by examining local effects of terms on a specific prediction and by computing the partial dependence of the predictions on predictors.

Load the 1994 census data stored in `census1994.mat`. The data set consists of demographic data from the US Census Bureau that you can use to predict whether an individual makes over \$50,000 per year. The classification task is to fit a model that predicts the salary category of people given their age, working class, education level, marital status, race, and so on.

`load census1994`

`census1994` contains the training data set `adultdata` and the test data set `adulttest`. To reduce the running time for this example, subsample 500 training observations and 500 test observations by using the `datasample` function.

```
rng('default')
NumSamples = 5e2;
adultdata = datasample(adultdata,NumSamples,'Replace',false);
adulttest = datasample(adulttest,NumSamples,'Replace',false);
```

### Find Optimal Parameters for Univariate GAM

Optimize the parameters for a univariate GAM with respect to the cross-validation loss by using the `bayesopt` function.

Prepare `optimizableVariable` objects for the name-value arguments of a univariate GAM: `MaxNumSplitsPerPredictor`, `NumTreesPerPredictor`, and `InitialLearnRateForPredictors`.

```
maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');
```

Create an objective function that takes an input `z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors]` and returns the cross-validated loss value at the parameters in `z`.

```
minfun1 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));
```

If you specify the cross-validation option `'CrossVal','on'`, then the `fitcgam` function returns a cross-validated model object `ClassificationPartitionedGAM`. The `kfoldLoss` function returns the classification loss obtained by the cross-validated model. Therefore, the function handle `minfun1` computes the cross-validation loss at the parameters in `z`.

Search for the best parameters using `bayesopt`. For reproducibility, choose the `'expected-improvement-plus'` acquisition function. The default acquisition function depends on run time and, therefore, can give varying results.

```
results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
```
```
|====================================================================================================================|
| Iter | Eval   | Objective  | Objective  | BestSoFar  | BestSoFar  | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |            | runtime    | (observed) | (estim.)   | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |    0.18549 |     5.6957 |    0.18549 |    0.18549 |      0.73503 |            7 |           99 |
|    2 | Accept |    0.19145 |     20.383 |    0.18549 |    0.18549 |      0.72917 |           10 |          399 |
|    3 | Best   |    0.17703 |     13.412 |    0.17703 |    0.17703 |     0.079299 |            8 |          267 |
|    4 | Best   |    0.14955 |      0.402 |    0.14955 |    0.14955 |      0.24236 |            4 |            3 |
|    5 | Accept |    0.15999 |     12.363 |    0.14955 |    0.14955 |      0.25509 |            1 |          377 |
|    6 | Accept |    0.15158 |     1.5035 |    0.14955 |    0.14955 |      0.23051 |            7 |           29 |
|    7 | Accept |    0.16181 |    0.18204 |    0.14955 |    0.14955 |      0.34396 |            4 |            1 |
|    8 | Accept |    0.15079 |    0.38418 |    0.14955 |    0.14955 |      0.26669 |           10 |            5 |
|    9 | Accept |    0.16102 |    0.55525 |    0.14955 |    0.14955 |      0.26065 |            2 |           10 |
|   10 | Accept |    0.19259 |     8.6487 |    0.14955 |    0.14955 |      0.24894 |           10 |          182 |
|   11 | Accept |    0.18628 |    0.20681 |    0.14955 |    0.14955 |      0.13389 |            6 |            2 |
|   12 | Accept |    0.15653 |    0.24643 |    0.14955 |    0.14955 |      0.24172 |           10 |            2 |
|   13 | Best   |    0.14699 |    0.82743 |    0.14699 |    0.14699 |      0.26745 |            7 |           12 |
|   14 | Best   |    0.14634 |    0.47528 |    0.14634 |    0.14634 |      0.25025 |            6 |            6 |
|   15 | Best   |    0.14312 |    0.34493 |    0.14312 |    0.14312 |      0.30452 |            9 |            3 |
|   16 | Accept |    0.14334 |    0.51583 |    0.14312 |    0.14312 |      0.33507 |           10 |            7 |
|   17 | Best   |    0.13791 |    0.32248 |    0.13791 |    0.13791 |      0.33179 |            9 |            4 |
|   18 | Accept |    0.14875 |     0.3551 |    0.13791 |    0.13791 |      0.36806 |            8 |            5 |
|   19 | Accept |     0.1651 |     1.3731 |    0.13791 |    0.13791 |      0.32691 |            8 |           27 |
|   20 | Accept |    0.15895 |    0.37324 |    0.13791 |    0.13791 |      0.32985 |            7 |            5 |
|   21 | Accept |    0.13946 |    0.26793 |    0.13791 |    0.13791 |      0.36721 |            9 |            3 |
|   22 | Accept |    0.16719 |     1.1276 |    0.13791 |    0.13791 |      0.25385 |            5 |           23 |
|   23 | Accept |    0.17017 |       1.35 |    0.13791 |    0.13791 |      0.23809 |            9 |           26 |
|   24 | Accept |    0.15519 |    0.46246 |    0.13791 |    0.13791 |      0.34831 |            9 |            7 |
|   25 | Accept |    0.15312 |    0.26445 |    0.13791 |    0.13791 |      0.33416 |           10 |            3 |
|   26 | Accept |    0.15852 |    0.31045 |    0.13791 |    0.13791 |       0.6142 |            9 |            4 |
|   27 | Accept |    0.16691 |    0.50559 |    0.13791 |    0.13791 |      0.31446 |            5 |            7 |
|   28 | Accept |    0.14384 |    0.35136 |    0.13791 |    0.13791 |      0.40215 |            9 |            4 |
|   29 | Accept |    0.14773 |    0.33296 |    0.13791 |    0.13791 |      0.34255 |            9 |            4 |
|   30 | Accept |    0.17604 |    0.85847 |    0.13791 |    0.13791 |      0.36565 |            6 |           15 |
```

```
__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 97.6656 seconds
Total objective function evaluation time: 74.4022

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                          4          

Observed objective function value = 0.13791
Estimated objective function value = 0.13791
Function evaluation time = 0.32248

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                          4          

Estimated objective function value = 0.13791
Estimated function evaluation time = 0.33084
```

Obtain the best point from `results1`.

`zbest1 = bestPoint(results1)`
```
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                          4          
```

### Train Univariate GAM with Optimal Parameters

Train an optimized GAM using the `zbest1` values. A recommended practice is to specify the class names.

```
Mdl1 = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor)
```
```
Mdl1 = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7383
          NumObservations: 500

  Properties, Methods
```

`Mdl1` is a `ClassificationGAM` model object. The model display shows a partial list of the model properties. To view the full list of the model properties, double-click the variable name `Mdl1` in the Workspace. The Variables editor opens for `Mdl1`. Alternatively, you can display the properties in the Command Window by using dot notation. For example, display the `ReasonForTermination` property.

`Mdl1.ReasonForTermination`
```
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''
```

The `PredictorTrees` field of the property value indicates that `Mdl1` includes the specified number of trees. `NumTreesPerPredictor` of `fitcgam` specifies the maximum number of trees per predictor, and the function can stop before training the requested number of trees. You can use the `ReasonForTermination` property to determine whether the trained model contains the specified number of trees.

If you specify to include interaction terms so that `fitcgam` trains trees for them, then the `InteractionTrees` field contains a nonempty value.

### Find Optimal Parameters for Bivariate GAM

Find the parameters for interaction terms of a bivariate GAM by using the `bayesopt` function.

Prepare `optimizableVariable` objects for the name-value arguments for the interaction terms: `InitialLearnRateForInteractions`, `MaxNumSplitsPerInteraction`, `NumTreesPerInteraction`, and `Interactions`.

```
initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');
```

Create an objective function for the optimization. Use the optimal parameter values in `zbest1` so that the software finds optimal parameter values for interaction terms based on the `zbest1` values.

```
minfun2 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));
```

Search for the best parameters using `bayesopt`. The optimization process trains multiple models and displays warning messages if the models include no interaction terms. Disable all warnings before calling `bayesopt`, and restore the warning state after running `bayesopt`. If you want to view the warning messages, leave the warning state unchanged.

```
orig_state = warning('query');
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
```
```
|===================================================================================================================================|
| Iter | Eval   | Objective  | Objective  | BestSoFar  | BestSoFar  | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |            | runtime    | (observed) | (estim.)   | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |    0.19671 |     10.999 |    0.19671 |    0.19671 |      0.96444 |            8 |          109 |           22 |
|    2 | Best   |      0.189 |      30.57 |      0.189 |      0.189 |      0.98548 |            6 |          457 |           17 |
|    3 | Best   |    0.16538 |     18.643 |    0.16538 |    0.16538 |      0.28678 |            4 |          383 |           13 |
|    4 | Best   |    0.15243 |     0.4285 |    0.15243 |    0.15243 |      0.28044 |            1 |           45 |            3 |
|    5 | Accept |    0.16065 |    0.69005 |    0.15243 |    0.15243 |      0.20151 |            7 |           60 |            1 |
|    6 | Best   |    0.14831 |    0.36629 |    0.14831 |    0.14831 |     0.032423 |            1 |          151 |            1 |
|    7 | Accept |    0.14887 |    0.36443 |    0.14831 |    0.14831 |     0.021093 |            1 |           15 |            1 |
|    8 | Accept |    0.15039 |    0.42139 |    0.14831 |    0.14831 |     0.012128 |            2 |          482 |            1 |
|    9 | Best   |    0.14787 |    0.42482 |    0.14787 |    0.14787 |      0.10119 |            1 |          121 |            6 |
|   10 | Best   |    0.13902 |    0.38822 |    0.13902 |    0.13902 |       0.1233 |            1 |          281 |            3 |
|   11 | Accept |    0.14721 |    0.39532 |    0.13902 |    0.13902 |     0.065618 |            1 |          291 |            3 |
|   12 | Accept |    0.14586 |    0.39205 |    0.13902 |    0.13902 |      0.18711 |            1 |          117 |            1 |
|   13 | Accept |    0.15073 |      0.383 |    0.13902 |    0.13902 |      0.15072 |            1 |           15 |            3 |
|   14 | Accept |    0.14966 |    0.42744 |    0.13902 |    0.13902 |      0.17155 |            1 |          497 |            4 |
|   15 | Best   |    0.13716 |    0.37599 |    0.13716 |    0.13716 |      0.12601 |            1 |          281 |            1 |
|   16 | Accept |    0.15094 |    0.38197 |    0.13716 |    0.13716 |      0.13962 |            2 |          284 |            1 |
|   17 | Accept |    0.13972 |     4.5994 |    0.13716 |    0.13716 |    0.0028545 |            5 |          481 |            2 |
|   18 | Accept |    0.14788 |     31.639 |    0.13716 |    0.13716 |    0.0024433 |            6 |          489 |           15 |
|   19 | Accept |    0.14565 |      1.276 |    0.13716 |    0.13716 |     0.013118 |            5 |          257 |            1 |
|   20 | Accept |    0.16502 |     28.315 |    0.13716 |    0.13716 |    0.0063353 |            4 |          457 |           16 |
|   21 | Accept |    0.15693 |     4.9653 |    0.13716 |    0.13716 |     0.016486 |            6 |          466 |            2 |
|   22 | Accept |    0.16312 |     29.942 |    0.13716 |    0.13716 |     0.019904 |            5 |          488 |           15 |
|   23 | Accept |    0.15719 |     4.7423 |    0.13716 |    0.13716 |     0.020155 |            4 |          456 |            3 |
|   24 | Best   |      0.129 |     6.4419 |      0.129 |      0.129 |     0.090858 |            5 |          478 |            3 |
|   25 | Accept |    0.15118 |     6.6757 |      0.129 |      0.129 |      0.15943 |            5 |          494 |            3 |
|   26 | Accept |    0.15343 |     2.2035 |      0.129 |      0.129 |     0.070349 |            5 |          489 |            1 |
|   27 | Best   |    0.12879 |     6.8017 |    0.12879 |    0.12879 |     0.091985 |            5 |          387 |            4 |
|   28 | Accept |    0.19093 |     5.9262 |    0.12879 |    0.12879 |     0.067405 |            5 |          331 |            4 |
|   29 | Accept |    0.16767 |     6.3779 |    0.12879 |    0.12879 |      0.31419 |            5 |          472 |            3 |
|   30 | Accept |    0.17636 |     11.026 |    0.12879 |    0.12879 |     0.054697 |            5 |          383 |            7 |
```

```
__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 239.1035 seconds
Total objective function evaluation time: 216.5833

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                           387                    4       

Observed objective function value = 0.12879
Estimated objective function value = 0.12879
Function evaluation time = 6.8017

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                           387                    4       

Estimated objective function value = 0.12879
Estimated function evaluation time = 6.7245
```
`warning(orig_state)`

Obtain the best point from `results2`.

`zbest2 = bestPoint(results2)`
```
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                           387                    4       
```

### Train Bivariate GAM with Optimal Parameters

Train an optimized GAM using the `zbest1` and `zbest2` values.

```
Mdl = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...
    'Interactions',zbest2.numInteractions)
```
```
Mdl = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7755
             Interactions: [4×2 double]
          NumObservations: 500

  Properties, Methods
```

Alternatively, you can add interaction terms to the univariate GAM by using the `addInteractions` function.

```
Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction);
```

The second input argument specifies the maximum number of interaction terms, and the `NumTreesPerInteraction` name-value argument specifies the maximum number of trees per interaction term. The `addInteractions` function can include fewer interaction terms and stop before training the requested number of trees. You can use the `Interactions` and `ReasonForTermination` properties to check the actual number of interaction terms and number of trees in the trained model.
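For example, assuming you created `Mdl2` with `addInteractions` as shown above, a quick sketch of both checks looks like this:

```
% Number of interaction terms actually included in Mdl2
% (can be fewer than the requested zbest2.numInteractions)
size(Mdl2.Interactions,1)

% Reason for termination for the interaction trees in Mdl2
Mdl2.ReasonForTermination.InteractionTrees
```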

Display the interaction terms in `Mdl`.

`Mdl.Interactions`
```
ans = 4×2

     7    10
     4     7
     7     9
     5    10
```

Each row of `Interactions` represents one interaction term and contains the column indexes of the predictor variables for the interaction term. You can use the `Interactions` property to check the interaction terms in the model and the order in which `fitcgam` adds them to the model.

Display the interaction terms in `Mdl` using the predictor names.

`Mdl.PredictorNames(Mdl.Interactions)`
```
ans = 4×2 cell
    {'relationship'  }    {'capital_gain'}
    {'education_num' }    {'relationship'}
    {'relationship'  }    {'sex'         }
    {'marital_status'}    {'capital_gain'}
```

Display the reason for termination to determine whether the model contains the specified number of trees for each linear term and each interaction term.

`Mdl.ReasonForTermination`
```
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'
```

### Assess Predictive Performance on New Observations

Assess the performance of the trained model by using the test sample `adulttest` and the object functions `predict`, `loss`, `edge`, and `margin`. You can use a full or compact model with these functions.

If you want to assess the performance of the training data set, use the resubstitution object functions: `resubPredict`, `resubLoss`, `resubMargin`, and `resubEdge`. To use these functions, you must use the full model that contains the training data.
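For instance, a minimal resubstitution sketch, assuming the full model `Mdl` trained earlier in this example:

```
% Resubstitution statistics are computed on the training data stored in
% the full model, so they tend to give an optimistic performance estimate.
resubL = resubLoss(Mdl);   % classification loss on the training data
resubE = resubEdge(Mdl);   % mean classification margin on the training data
```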

Create a compact model to reduce the size of the trained model.

```
CMdl = compact(Mdl);
whos('Mdl','CMdl')
```
```
  Name      Size              Bytes  Class                                                 Attributes

  CMdl      1x1             3272176  classreg.learning.classif.CompactClassificationGAM              
  Mdl       1x1             3389515  ClassificationGAM                                               
```

Predict labels and scores for the test data set (`adulttest`), and compute model statistics (loss, margin, and edge) using the test data set.

```
[labels,scores] = predict(CMdl,adulttest);
L = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt);
M = margin(CMdl,adulttest);
E = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt);
```

Predict labels and scores and compute the statistics without including interaction terms in the trained model.

```
[labels_nointeraction,scores_nointeraction] = predict(CMdl,adulttest,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);
M_nointeractions = margin(CMdl,adulttest,'IncludeInteractions',false);
E_nointeractions = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);
```

Compare the results obtained by including both linear and interaction terms to the results obtained by including only linear terms.

Create a table containing the observed labels, predicted labels, and scores. Display the first eight rows of the table.

```
t = table(adulttest.salary,labels,scores,labels_nointeraction,scores_nointeraction, ...
    'VariableNames',{'True Labels','Predicted Labels','Scores' ...
    'Predicted Labels without interactions','Scores without interactions'});
head(t)
```
```
ans=8×5 table
    True Labels    Predicted Labels           Scores           Predicted Labels without interactions    Scores without interactions
    ___________    ________________    ____________________    _____________________________________    ___________________________

       <=50K            <=50K          0.97921      0.020787                   <=50K                      0.98005       0.019951   
       <=50K            <=50K                1      8.258e-17                  <=50K                       0.9713       0.028696   
       <=50K            <=50K                1     1.8297e-19                  <=50K                      0.99449       0.0055054  
       <=50K            <=50K          0.87422       0.12578                   <=50K                      0.87729        0.12271   
       <=50K            <=50K                1     3.5643e-07                  <=50K                      0.99882       0.0011769  
       <=50K            <=50K          0.60371       0.39629                   <=50K                      0.77861        0.22139   
       <=50K            >50K           0.49917       0.50083                   >50K                       0.46877        0.53123   
       >50K             >50K            0.3109        0.6891                   <=50K                      0.53571        0.46429   
```

Create a confusion chart from the true labels `adulttest.salary` and the predicted labels.

```
tiledlayout(1,2);
nexttile
confusionchart(adulttest.salary,labels)
title('Linear and Interaction Terms')
nexttile
confusionchart(adulttest.salary,labels_nointeraction)
title('Linear Terms Only')
```

Display the computed loss and edge values.

```
table([L; E], [L_nointeractions; E_nointeractions], ...
    'VariableNames',{'Linear and Interaction Terms','Only Linear Terms'}, ...
    'RowNames',{'Loss','Edge'})
```
```
ans=2×2 table
            Linear and Interaction Terms    Only Linear Terms
            ____________________________    _________________

    Loss              0.14868                     0.13852    
    Edge              0.63926                     0.58405    
```

The model achieves a smaller loss when only linear terms are included, but achieves a higher edge value when both linear and interaction terms are included.

Display the distributions of the margins using box plots.

```
figure
boxplot([M M_nointeractions],'Labels',{'Linear and Interaction Terms','Linear Terms Only'})
title('Box Plots of Test Sample Margins')
```

### Interpret Prediction

Interpret the prediction for the first test observation by using the `plotLocalEffects` function. Also, create partial dependence plots for some important terms in the model by using the `plotPartialDependence` function.

Classify the first observation of the test data, and plot the local effects of the terms in `CMdl` on the prediction. To display an existing underscore in any predictor name, change the `TickLabelInterpreter` value of the axes to `'none'`.

`label = predict(CMdl,adulttest(1,:))`
```
label = categorical
     <=50K 
```
```
f1 = figure;
plotLocalEffects(CMdl,adulttest(1,:))
f1.CurrentAxes.TickLabelInterpreter = 'none';
```

The `predict` function classifies the first observation `adulttest(1,:)` as `'<=50K'`. The `plotLocalEffects` function creates a horizontal bar graph that shows the local effects of the 10 most important terms on the prediction. Each local effect value shows the contribution of each term to the classification score for `'<=50K'`, which is the logit of the posterior probability that the classification is `'<=50K'` for the observation.
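As a sanity check, you can recover the logit scale on which the local effects are expressed from the posterior probabilities that `predict` returns (a sketch assuming the `CMdl` model from this example):

```
% Second output of predict contains the class scores; with the 'logit'
% score transform, these are posterior probabilities [P(<=50K) P(>50K)].
[~,posterior] = predict(CMdl,adulttest(1,:));

% The local effects and intercept sum on this logit scale for '<=50K'.
logitScore = log(posterior(1)/(1 - posterior(1)));
```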

Create a partial dependence plot for the term `age`. Specify both the training and test data sets to compute the partial dependence values using both sets.

```
figure
plotPartialDependence(CMdl,'age',label,[adultdata; adulttest])
```

The plotted line represents the averaged partial relationships between the predictor `age` and the score of the class `<=50K` in the trained model. The `x`-axis minor ticks represent the unique values in the predictor `age`.

Create partial dependence plots for the terms `education_num` and `relationship`.

```
f2 = figure;
plotPartialDependence(CMdl,["education_num","relationship"],label,[adultdata; adulttest])
f2.CurrentAxes.TickLabelInterpreter = 'none';
```

The plot shows the partial dependence on `education_num`, which has a different trend depending on the `relationship` value.