Explainable AI: interpreting the classification using LIME

version 2.1 (3.69 MB) by Kenta

Updated 24 Jul 2021


Explainable AI: interpreting the classification performed by deep learning with LIME

This demo shows how to interpret an image classification by a CNN using LIME (Local Interpretable Model-agnostic Explanations) [1]. It was created based on [1], but the implementation may differ slightly from the official one. The code highlights the regions that contributed to the classification. This helps you interpret and improve the model, or recognize that the classifier should not be trusted when the highlighted region is irrelevant to the true class.

The thumbnail visualizes the regions that a pre-trained network (ResNet-18 [2]) relied on when it predicted "golden retriever".

[Key words]
classification, cnn (convolutional neural network), deep learning, explainable AI, image, interpret, LIME (Local Interpretable Model-agnostic Explanations), machine learning, superpixel, visualization, why

[Reference]
[1] Ribeiro, M.T., Singh, S. and Guestrin, C., 2016, August. "Why should I trust you?": Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (pp. 1135-1144).

[2] He, K., Zhang, X., Ren, S. and Sun, J., 2016. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).

image_0.png

Load the image

clear;clc;close all
% read the target image
I=imread('img.png');
figure;imshow(I);title('target image')

figure_0.png

Importing pre-trained network, ResNet-18

net=resnet18;

Classify the image and confirm that the result is correct.

Ypred=classify(net,imresize(I,[224 224]))
Ypred = 
     Egyptian cat 

Extract the index corresponding to the classification result.

% the final layer (no. 71) of ResNet-18 holds the class names
classIdx=find(net.Layers(71,1).Classes==Ypred);
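As an optional sanity check (not part of the original script), you can list the classes with the highest scores using the predict function; the top-5 listing below is just one convenient way to inspect the prediction.

```matlab
% optional sanity check: show the top-5 class scores
scores = predict(net,imresize(I,[224 224]));
[topScores,idx] = maxk(scores,5);
topClasses = net.Layers(71,1).Classes(idx);
table(topClasses(:),topScores(:),'VariableNames',{'Class','Score'})
```

If the true class does not appear near the top, the explanation computed below is for a prediction you already know to be wrong.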

Create superpixels

First, the target image is divided into superpixels. Note that the final result depends on this segmentation, so appropriate parameters have to be chosen.

Calculate superpixels of the image.

numSuperPixel=75;
[L,N] = superpixels(I,numSuperPixel);

For details of the superpixel generation, run the following command and review the algorithm.

% doc superpixels

Display the superpixel boundaries overlaid on the original image.

figure
BW = boundarymask(L);
imshow(imoverlay(I,BW,'cyan'),'InitialMagnification',100)

figure_1.png
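Because the explanation depends on this segmentation, it can help to compare a few settings of numSuperPixel side by side before committing to one. A rough sketch (the candidate values are arbitrary):

```matlab
% compare segmentations with different target superpixel counts
n = [25 75 200];
figure
for k = 1:numel(n)
    [Lk,~] = superpixels(I,n(k));
    subplot(1,numel(n),k)
    imshow(imoverlay(I,boundarymask(Lk),'cyan'))
    title(sprintf('numSuperPixel = %d',n(k)))
end
```

Too few superpixels give coarse, blocky explanations; too many make each region's individual contribution to the score harder to estimate.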

Sampling for Local Exploration

This section creates perturbed images as shown below. Each superpixel is randomly assigned 0 or 1; superpixels assigned 1 are kept, and the rest are blacked out.

image_1.png

% the number of perturbed images to generate
% a higher sampleNum gives a more reliable result at a higher
% computational cost
sampleNum=1000;
% similarity(i) stores the fraction of blacked-out superpixels
% (a dissimilarity to the original image, despite the variable name)
similarity=zeros(sampleNum,1);
indices=zeros(sampleNum,N);
img=zeros(224,224,3,sampleNum);
for i=1:sampleNum
    % randomly black-out the superpixels
    ind=rand(N,1)>rand(1)*.8;
    map=zeros(size(I,1:2));
    for j=find(ind)'
        ROI=L==j;
        map=ROI+map;
    end  
    img(:,:,:,i)=imresize(I.*uint8(map),[224 224]);
    % compute the dissimilarity (fraction of removed superpixels)
    % other distance metrics are also fine; this choice
    % also affects the result
    similarity(i)=1-nnz(ind)./numSuperPixel;
    indices(i,:)=ind;   
end
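Before running the (fairly slow) batch prediction, it can be worth eyeballing one perturbed sample against the original to confirm the masking works as intended. A minimal check, not part of the original script:

```matlab
% visual check: original vs. one perturbed sample
figure
subplot(1,2,1); imshow(I); title('original')
subplot(1,2,2); imshow(uint8(img(:,:,:,1))); title('perturbed sample #1')
```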

Predict the perturbed images using the CNN model to be interpreted

Use the activations function to obtain the classification score for the predicted class.

prob=activations(net,uint8(img),'prob','OutputAs','rows');
score=prob(:,classIdx);

Fitting using weighted linear model

Use the fitrlinear function to perform a weighted linear fit, passing the kernel weights via 'Weights',weights. Each row of the input indices contains 1s and 0s: for example, [1 0 1] means the first and third superpixels are kept while the second is masked with black. The response to predict is the classification score of each perturbed image. Note that in the original paper the sample weights are computed with an exponential kernel over a distance measure.

sigma=.35;
weights=exp(-similarity.^2/(sigma.^2));
mdl=fitrlinear(indices,score,'Learner','leastsquares','Weights',weights);
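The weighting above corresponds to the exponential kernel $\pi_x$ from the original paper, with the fraction of removed superpixels playing the role of the distance $D(x,z)$ between the original image $x$ and a perturbed sample $z$:

```latex
\pi_x(z) = \exp\!\left(-\frac{D(x,z)^2}{\sigma^2}\right), \qquad \sigma = 0.35
```

Samples close to the original image (few superpixels removed) thus get weights near 1, so the linear surrogate is fitted to be locally faithful around $x$.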

Confirm the exponential kernel used for the weighting.

x=0:0.01:1;
y=exp(-x.^2/(sigma.^2));
figure;plot(x,y)

figure_2.png

Displaying the result

This result is just one example of a LIME-based approach. It can change with different settings, such as the superpixel generation parameters, the fitting method (a linear model is used here), and the fitting parameters.

result=zeros(size(L));
for i=1:N
    ROI=L==i;
    result=result+ROI.*max(mdl.Beta(i),0); % keep only positive (supporting) coefficients
end

% smooth the LIME result; this step is not part of the official
% implementation
result2=imgaussfilt(result,8);
% display the final result
figure;imshow(I);hold on
imagesc(result2,'AlphaData',0.5);
colormap jet;colorbar;hold off;
title("Explanation using LIME");

figure_3.png
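As a variant (not part of the script above), you can keep only the few superpixels with the largest fitted coefficients to get a sharper, mask-style explanation instead of a heatmap; topK is an arbitrary choice here.

```matlab
% highlight only the topK most influential superpixels
topK = 5;
[~,order] = maxk(mdl.Beta,topK);
mask = ismember(L,order);
figure
imshow(labeloverlay(I,mask,'Colormap',[1 0 0],'Transparency',0.6))
title(sprintf('Top %d superpixels by LIME coefficient',topK))
```

This presentation is closer to the binary explanations shown in the original LIME paper, where only the selected superpixels are displayed.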

Cite As

Kenta (2021). Explainable AI: interpreting the classification using LIME (https://github.com/KentaItakura/Explainable-AI-interpreting-the-classification-performed-by-deep-learning-with-LIME-using-MATLAB/releases/tag/v2.1), GitHub. Retrieved .

MATLAB Release Compatibility
Created with R2020a
Compatible with any release
Platform Compatibility
Windows macOS Linux

To view or report issues in this GitHub add-on, visit the GitHub Repository.