How can I get a Shapley summary plot?
Fazel Bateni
on 4 Nov 2021
Commented: Chirdeep N R
on 17 Jun 2024
I have checked the MATLAB syntax for the Shapley value plots, but the examples didn't help me figure out how to create a Shapley summary plot similar to the attached image. Can you please help me out?
In Python, you can use the shap library to understand how much each input variable in a machine learning model contributes to the model prediction. However, I don't have that flexibility in MATLAB.
Ref: https://towardsdatascience.com/explain-your-model-with-the-shap-values-bc36aac4de3d
Haidong Zhao
on 16 Jun 2022
Hi Fazel, did you solve this question? I am also interested in it.
Best
haidong
Accepted Answer
Drew
on 18 Feb 2023
Edited: Drew
on 8 Jan 2024
At a high level, the steps to build this plot are:
(1) Load the data and build the model
(2) Calculate the Shapley values using the shapley function (introduced in R2021a)
(3) Create the Shapley summary plot
For R2024a and higher:
R2024a adds functionality that makes it easy to create Shapley summary plots, as described in the release notes: https://www.mathworks.com/help/releases/R2024a/stats/release-notes.html.
(As of this writing in January 2024, R2024a can be accessed by using the prerelease. The general release of R2024a is planned for March 2024.)
Given a trained machine learning model, you can now use the shapley and fit functions to compute Shapley values for multiple query points. After using the QueryPoints name-value argument to specify multiple query points, you can visualize the results in the fitted shapley object by using the following object functions:
- plot — Plot mean absolute Shapley values using bar graphs.
- boxchart — Visualize Shapley values using box charts (box plots).
- swarmchart — Visualize Shapley values using swarm scatter charts.
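A minimal sketch of the two-step shapley/fit workflow and the three object functions listed above, assuming R2024a or later. It uses the carsmall data that ships with MATLAB instead of the wine data; the boosted ensemble, the rmmissing cleanup, and the tiled layout are illustrative choices, not part of the original answer.

```matlab
% Sketch for R2024a or later: Shapley values for many query points via fit,
% then the three multi-query-point visualizations.
load carsmall                                    % ships with MATLAB
tbl = rmmissing(table(Horsepower,Weight,MPG));   % drop rows containing NaN
mdl = fitrensemble(tbl,'MPG','Method','LSBoost','NumLearningCycles',100);

% Background data: the predictor data stored in mdl.
explainer = shapley(mdl);
% Compute Shapley values for every row of tbl (multiple query points).
explainer = fit(explainer,tbl);

figure; tiledlayout(1,3);
nexttile; plot(explainer);                           % bar graph of mean(abs(Shapley values))
nexttile; boxchart(explainer);                       % box chart per predictor
nexttile; swarmchart(explainer,ColorMap='bluered');  % Shapley summary plot
```

Passing 'QueryPoints' directly to shapley, as the examples below do, should give the same result in a single call.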
Here is a regression example on the widely available wine quality data set, which is the data set pictured in the question.
% Load the data, choose response variable
alldata=readtable("winequality-red.csv");
response='quality';
% Split data into train and test. Choose 20% of the data for testing.
rng("default") % For reproducibility
c = cvpartition(size(alldata,1),"Holdout",0.20);
x_train=alldata(training(c),:);
x_test=alldata(test(c),:);
% Create a model which is an ensemble of bagged trees (random forest),
% using the training set
mdl=fitrensemble(x_train,response,'Method','Bag', ...
    'Learners',templateTree('MaxNumSplits',126),'NumLearningCycles',10);
% Get Shapley values using all of the training data as background sample,
% and using all data (training and test) as query points.
explainer = shapley(mdl,x_train,'QueryPoints',alldata)
% Get visualization
figure(1);clf;
swarmchart(explainer,NumImportantPredictors=11,ColorMap='bluered')
Here is the resulting figure when using the default figure window size. Notice that a single line of code calculates the Shapley values across all of the query points, and a single line of code creates this plot.
Eleven predictors were squeezed onto that plot, so there is some undesirable overlap of data points from different predictors. That can easily be fixed in a variety of ways.
(1) If the figure window is resized to be sufficiently large, the undesirable overlap of data points from different predictors goes away. Next, if desired, the font size of the labels can be increased. The image below is an example result; there are many variations depending on the figure window size and aspect ratio, and on the font size of the labels.
(2) Alternatively, one could get the handle to the underlying scatter object for each predictor and reduce the data marker size, that is, the 'SizeData' property.
% We squeezed 11 features onto the plot. One way to adjust for that is to
% shrink the data markers. We cycle through the children of the axes and
% decrease the 'SizeData' property (the marker area, in points squared).
% There is one scatter object for each predictor.
h=get(gca,'children');
ReductionFactorMarkerSize = 0.2; % Make the marker area 20% of its original value
for i=1:length(h)
    if isa(h(i),'matlab.graphics.chart.primitive.Scatter')
        % Reduce the marker size. The default SizeData is 36.
        InitialMarkerSize = get(h(i),'SizeData');
        set(h(i),'SizeData',InitialMarkerSize*ReductionFactorMarkerSize);
    end
end
This also achieves the goal of avoiding overlap between data points for different predictors. Here is the resulting figure, at the default figure window size. This plot can easily be resized to adjust the look as desired.
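Option (1) can also be scripted instead of dragging the window. A minimal sketch, assuming the summary plot is in the current figure; the window size and font size are arbitrary example values, not values from the original answer.

```matlab
% Sketch of option (1) in code. Run it after creating the swarmchart.
fig = gcf;                        % current figure (a new one is created if none exists)
fig.Units = 'pixels';
fig.Position(3:4) = [1000 800];   % [width height] of the figure window, in pixels
ax = gca;                         % current axes, i.e. the summary plot
ax.FontSize = 14;                 % larger tick labels (predictor names) and axis labels
```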
For R2024a or higher, here is a classification example using the fisheriris dataset:
% Load the data, identify the response variable
alldata=readtable("fisheriris.csv");
response='Species';
% Split data into train and test. Choose 20% of the data for testing.
% Use cvpartition to stratify the sample via classes, so that the
% train and test sets have similar class distribution.
rng("default") % For reproducibility
c = cvpartition(alldata.(response),"Holdout",0.20);
x_train=alldata(training(c),:);
x_test=alldata(test(c),:);
% Create a random forest model with 7 trees (NumLearningCycles)
mdl=fitcensemble(x_train,response,'Method','Bag','NumLearningCycles',7);
% Get Shapley values using the training set as the background sample,
% and "alldata" (train and test) as query points. Note that this multi-query-point
% call to the shapley function requires R2024a or higher
explainer = shapley(mdl, x_train, 'QueryPoints',alldata)
% Plot visualization of mean(abs(shap)) bar plot, and swarmchart for each
% output class. Note that these multi-query-point plots require R2024a or
% higher
figure(1); clf; tiledlayout(2,2); nexttile(1);
% Plot the mean(abs(shap)) plot for this multi-query-point shapley object
plot(explainer);
% Plot the shapley summary swarmchart for each output class
for i=2:4
    nexttile(i);
    swarmchart(explainer,ClassName=mdl.ClassNames{i-1},ColorMap='bluered')
end
Here is the resulting tiled plot, which includes a Shapley importance plot in the upper left and a Shapley summary swarmchart for each output class.
For R2023b and earlier:
An example is shown below using the carsmall dataset, without using the new shapley functionality in R2024a. The two predictors Horsepower and Weight are used to predict the MPG (Miles Per Gallon) of a car.
(1) Load the data and build the model
% Follows example at https://www.mathworks.com/help/stats/train-regression-ensemble.html,
% with the addition of storing the data in a table.
load carsmall
x_train=table(Horsepower,Weight,MPG);
response='MPG';
% n = number of observations in the training set x_train
% m = number of columns in "x_train"
% d = number of features (predictors)
[n,m]=size(x_train);
d=m-1; % subtract one, because last column is the response
% Build Model
Mdl=fitrensemble(x_train,response,'Method','LSBoost','NumLearningCycles',100);
(2) Calculate Shapley values one query point at a time
% Initialize the shap matrix. The indices will be: (query-point-index, predictor-index)
shap=zeros(n,d);
% loop over the query points and calculate Shapley Values
% We can parallelize over query points with a "parfor" loop instead of a
% "for" loop. For this small example running inside MATLAB Answers, we
% use a "for" loop.
tic;
for i=1:n
    % Set 'UseParallel' to true to parallelize inside the shapley function.
    % For this small example running inside MATLAB Answers, we set
    % 'UseParallel' to false.
    explainer=shapley(Mdl,'QueryPoint',x_train(i,:),'UseParallel',false);
    % Store the Shapley values in a matrix referenced by (query-point-index, predictor-index)
    shap(i,:)=explainer.ShapleyValues{:,2};
end
toc;
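The comments in the loop above mention parfor. A minimal sketch of that variant, assuming Parallel Computing Toolbox is installed (without it, parfor runs the iterations serially). It repeats the step (1) setup so it runs on its own, and copies the query points into a cell array first; that cell array, rows, is a name introduced here so the loop only slices a plain cell array.

```matlab
% Sketch: parfor version of step (2), using the R2021a to R2023b syntax.
load carsmall
x_train = table(Horsepower,Weight,MPG);
response = 'MPG';
[n,m] = size(x_train);
d = m-1;                                  % last column is the response
Mdl = fitrensemble(x_train,response,'Method','LSBoost','NumLearningCycles',100);

% One table row (query point) per cell, so parfor slices a cell array.
rows = cell(n,1);
for k = 1:n
    rows{k} = x_train(k,:);
end

shap = zeros(n,d);                        % (query-point-index, predictor-index)
parfor i = 1:n
    % Keep 'UseParallel' false: the parallelism is over query points.
    explainer = shapley(Mdl,'QueryPoint',rows{i},'UseParallel',false);
    shap(i,:) = explainer.ShapleyValues{:,2}';   % d-by-1 column -> 1-by-d row
end
```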
(3) Create the Shapley summary plot using multiple calls to "scatter"
% Sort the predictors by mean(abs(shap))
[sortedMeanAbsShapValues,sortedPredictorIndices]=sort(mean(abs(shap)));
% Loop over the predictors, plot a row of points for each predictor using
% the scatter function with "density" jitter.
% The multiple calls to scatter are needed so that the jitter is normalized
% per-row, rather than globally over all the rows.
for p=1:d
    scatter(shap(:,sortedPredictorIndices(p)), ... % x-value of each point is the shapley value
        p*ones(n,1), ... % y-value of each point is an integer corresponding to a predictor (to be jittered below)
        [], ... % Marker size for each data point, taking the default here
        normalize(table2array(x_train(:,sortedPredictorIndices(p))),'range',[1 256]), ... % Colors based on feature values
        'filled', ... % Fills the circles representing data points
        'YJitter','density', ... % YJitter according to the density of the points in this row
        'YJitterWidth',0.8)
    if (p==1) hold on; end
end
title('Shapley Summary plot');
xlabel('Shapley Value (impact on model output)')
yticks([1:d]);
yticklabels(x_train.Properties.VariableNames(sortedPredictorIndices));
% Set colormap as desired
colormap(CoolBlueToWarmRedColormap); % This colormap is like the one used in many Shapley summary plots
% colormap(parula); % This is the default colormap
cb= colorbar('Ticks', [1 256], 'TickLabels', {'Low', 'High'});
cb.Label.String = "Scaled Feature Value";
cb.Label.FontSize = 12;
cb.Label.Rotation = 270;
set(gca, 'YGrid', 'on');
xline(0, 'LineWidth', 1);
hold off;
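Because sort is ascending and row p is drawn at y = p, the predictor with the largest mean(abs(shap)) ends up on the top row, as in the usual summary plot. A small check with made-up toy values (not model output):

```matlab
% Toy Shapley values: 2 query points (rows) x 3 predictors (columns)
shap = [ 0.5  -2.0   0.1;
        -0.3   1.0  -0.2];
% Mean absolute Shapley value per predictor: [0.4 1.5 0.15]
[sortedMeanAbsShapValues,sortedPredictorIndices] = sort(mean(abs(shap)));
% Ascending order [3 1 2]: predictor 3 is drawn at y = 1 (bottom) and
% predictor 2, the most important one, at y = 3 (top).
disp(sortedPredictorIndices)
```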
A few notes:
- For Weight and Horsepower, there are many query points where high values of those features have negative Shapley values. This is as expected, since high values for those predictors will generally tend to reduce the MPG of a car.
- The Shapley summary plot colorbar can be extended to categorical features by mapping the categories to integers using the "unique" function, e.g., [~, ~, integerReplacement]=unique(originalCategoricalArray).
- For classification problems, a Shapley summary plot can be created for each output class. In that case, the shap variable could be a tensor ("3-D matrix") with indices as: (query-point-index, predictor-index, output-class-index)
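For the classification note above, a minimal sketch of filling that 3-D array with the pre-R2024a single-query-point syntax from step (2). It assumes a bagged ensemble like the R2024a fisheriris example, but trained here on the numeric matrix meas from fisheriris.mat rather than the CSV table. It also relies on ShapleyValues for a classification model holding the predictor names in its first column, followed by one column of Shapley values per class.

```matlab
% Sketch (R2021a to R2023b syntax): per-class Shapley values stored with
% indices (query-point-index, predictor-index, output-class-index).
load fisheriris                              % meas (150x4 double), species (150x1 cell)
Mdl = fitcensemble(meas,species,'Method','Bag','NumLearningCycles',7);
[n,d] = size(meas);
K = numel(Mdl.ClassNames);                   % number of output classes
shap = zeros(n,d,K);
for i = 1:n
    explainer = shapley(Mdl,'QueryPoint',meas(i,:));
    % Columns 2:end hold one d-by-1 vector of Shapley values per class.
    shap(i,:,:) = explainer.ShapleyValues{:,2:end};
end
% n-by-d slice for one class (the first entry of Mdl.ClassNames); it can be
% plotted exactly like the regression shap matrix in step (3).
shapClass1 = shap(:,:,1);
```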
Function to create CoolBlueToWarmRedColormap
function colormap = CoolBlueToWarmRedColormap()
    % Define start point, middle luminance, and end point in L*ch colorspace
    % https://www.mathworks.com/help/images/device-independent-color-spaces.html
    % The three components of L*ch are luminance, chroma, and hue (hue in radians here).
    blue_lch = [54 70 4.6588];     % Starting blue point
    l_mid = 40;                    % Luminance of the midpoint
    red_lch = [54 90 6.6378909];   % Ending red point
    nsteps = 256;
    % Build matrix of L*ch colors that is nsteps x 3 in size.
    % Luminance changes linearly from start to middle, and middle to end.
    % Chroma and hue change linearly from start to end.
    lch=[[linspace(blue_lch(1), l_mid, nsteps/2), linspace(l_mid, red_lch(1), nsteps/2)]', ... luminance column
         [linspace(blue_lch(2), red_lch(2), nsteps)]', ... chroma column
         [linspace(blue_lch(3), red_lch(3), nsteps)]'];  % hue column
    % Convert L*ch to L*a*b, where a = c * cos(h) and b = c * sin(h)
    lab=[lch(:,1) lch(:,2).*cos(lch(:,3)) lch(:,2).*sin(lch(:,3))];
    % Convert L*a*b to RGB (lab2rgb is part of Image Processing Toolbox)
    colormap=lab2rgb(lab,'OutputType','uint8');
end
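Written out, the function builds row k = 1, ..., N of the colormap (N = nsteps = 256) by linear interpolation in L*ch, using the linspace step (b - a)/(m - 1), followed by a polar-to-Cartesian step:

```latex
\begin{aligned}
L_k &=
\begin{cases}
L_{\mathrm{blue}} + \dfrac{k-1}{N/2-1}\,\bigl(L_{\mathrm{mid}} - L_{\mathrm{blue}}\bigr), & 1 \le k \le N/2,\\[6pt]
L_{\mathrm{mid}} + \dfrac{k-N/2-1}{N/2-1}\,\bigl(L_{\mathrm{red}} - L_{\mathrm{mid}}\bigr), & N/2 < k \le N,
\end{cases}\\[6pt]
c_k &= c_{\mathrm{blue}} + \frac{k-1}{N-1}\,\bigl(c_{\mathrm{red}} - c_{\mathrm{blue}}\bigr),\\[4pt]
h_k &= h_{\mathrm{blue}} + \frac{k-1}{N-1}\,\bigl(h_{\mathrm{red}} - h_{\mathrm{blue}}\bigr),\\[4pt]
a_k &= c_k \cos h_k, \qquad b_k = c_k \sin h_k .
\end{aligned}
```

With the constants in the function (L_blue = L_red = 54, L_mid = 40, chroma from 70 to 90, hue from 4.6588 to 6.6378909 rad), the luminance dips to 40 in the middle of the map and the hue angle sweeps from about 267° to about 380° (equivalently 20°), that is, from blue through purple to red.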
Drew
on 15 Jun 2024
Edited: Drew
on 16 Jun 2024
There are many ways to get these two datasets. The Fisher iris data is available in MATLAB and can be loaded with the MATLAB code shown in the example:
alldata=readtable("fisheriris.csv");
Because Fisher iris ships with MATLAB, if you are using an installed version of MATLAB, you can also open the CSV file in another application (such as Excel) directly from the MATLAB installation directory. You can see where the "fisheriris.csv" file is located with the MATLAB command:
which fisheriris.csv
You can also download Fisher iris, or read it programmatically, from the UCI Machine Learning Repository: https://archive.ics.uci.edu/dataset/53/iris.
You can download wine quality (red and white), or read it programmatically, from the UCI Machine Learning Repository https://archive.ics.uci.edu/dataset/186/wine+quality. Download and open the zip package, then find the winequality-red.csv file.
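Another route is the fisheriris.mat sample file from Statistics and Machine Learning Toolbox: load fisheriris creates the variables meas and species. A minimal sketch that turns them into a table like the one used in the classification example; the variable names are chosen here and are assumed, not verified, to match the headers of fisheriris.csv.

```matlab
% Sketch: build the iris table from fisheriris.mat instead of the CSV file.
load fisheriris                                   % meas (150x4 double), species (150x1 cell)
alldata = array2table(meas,'VariableNames', ...
    {'SepalLength','SepalWidth','PetalLength','PetalWidth'});  % assumed names
alldata.Species = species;                        % response column used in the example
response = 'Species';
```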