How to get Shapley values for a neural network trained in MATLAB? It keeps giving an error...
Hi there,
I want to get the Shapley values of my pre-trained ANN.
It is a regression model.
Its input is a 7-by-5120 double and its output is a 1-by-5120 double.
I'm confused by the idea of Shapley values, sorry.
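To make the idea concrete, here is a minimal MATLAB sketch of the exact Shapley value for a hypothetical three-feature linear model (w, xq, b and v are illustrative names, not part of the question). A feature's Shapley value is its weighted average marginal contribution over every coalition S of the other features, with weight |S|!(p-|S|-1)!/p!. Features outside a coalition are set to a single baseline point b. That single baseline is a simplification: as far as I know, MATLAB's shapley instead averages over the rows of the predictor data you pass it.

```matlab
% Hypothetical toy model: linear in p = 3 features (illustration only)
w  = [2; -1; 0.5];
f  = @(z) w' * z;           % z is a 3-by-1 column vector, output is a scalar
xq = [1; 2; 3];             % query point to explain
b  = [0; 0; 0];             % baseline values used for "absent" features
p  = numel(xq);

% Value of coalition S (logical mask): features in S take their query values,
% all other features take their baseline values.
v = @(S) f(b + S(:) .* (xq - b));

phi = zeros(p, 1);
for i = 1:p
    for mask = 0:2^p - 1
        inS = logical(bitget(mask, 1:p));   % decode the bitmask into a coalition
        if inS(i), continue; end            % S must not contain feature i
        s = nnz(inS);
        weight = factorial(s) * factorial(p - s - 1) / factorial(p);
        withI = inS; withI(i) = true;
        % weighted marginal contribution of feature i to coalition S
        phi(i) = phi(i) + weight * (v(withI) - v(inS));
    end
end
% phi is [2; -2; 1.5], i.e. w .* (xq - b) for this linear model
```

For a linear model the result reduces to w_i(x_i - b_i), and for any model the values sum to f(xq) - f(b), which is a handy sanity check.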
Angelo Yeo
on 26 Aug 2024
Can you be more specific about your model and the error message? It would be best if you could share your model (code and data) and the steps to reproduce the error.
Answers (1)
Angelo Yeo
on 26 Aug 2024
I don't have your model and dataset, so I used random samples. The key is to use yticklabels to rename the predictors on the plots. Would this work for you?
clc;
clear;
%% Compute Shapley values
% demo neural network
x = randn(7, 150);
t = randn(1, 150);
net = fitnet(10);
net = configure(net,x,t); % only sets sizes and random weights; use net = train(net,x,t) for a trained model
% view(net)
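f = @(x) net(x')'; % wrap the network as a function: one sample per row in, one prediction per row out
x_veri_shapley = x(:,101:end)'; % transpose so that each row is one sample
x_train_shapley = x(:, 1:100)'; % transpose so that each row is one sample
% Sampling example
num_samples = size(x_veri_shapley,1); % number of samples to draw (here, all of them)
idx = randperm(size(x_veri_shapley, 1), num_samples);
x_veri_shapley_sampled = x_veri_shapley(idx, :);
% Set 'UseParallel' to true to enable parallel processing (requires Parallel Computing Toolbox)
explainer = shapley(f, x_train_shapley, 'QueryPoints', x_veri_shapley_sampled,'UseParallel', false);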
%%
% plot(explainer)
%%
% Copy the MeanAbsoluteShapley table
shapley_table = explainer.MeanAbsoluteShapley;
% Rename the predictors
desired_variable_names = ["PGA", "Dur_{sig}", "Sa_{max}", "Tm", "CAV_{max}", "Arias_{max}", "f_{1}"];
shapley_table.Predictor = desired_variable_names(:); % replace with the new predictor names
% Sort Shapley values and names in ascending order (barh draws the first
% element at the bottom, so the largest value ends up at the top)
[sorted_values, sort_index] = sort(shapley_table.ShapleyValue, 'ascend');
sorted_names = shapley_table.Predictor(sort_index);
% Bar chart (largest value at the top)
% figure;
% barh(sorted_values);
% set(gca, 'YTickLabel', sorted_names);
% xlabel('Mean absolute Shapley value');
% ylabel('Predictor');
% title('Shapley importance plot');
%%
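close all;
% The plots label predictors with their default names (the entries of
% explainer.MeanAbsoluteShapley.Predictor). Look each tick label up in that
% list so the new names stay correct whatever order a plot sorts them in.
defaultNames = string(explainer.MeanAbsoluteShapley.Predictor);
figure(10);
plot(explainer,QueryPointIndices=30); % bar order may differ from the mean-based order
hAxes = gca;
hAxes.TickLabelInterpreter = "tex";
[~, loc] = ismember(string(yticklabels(hAxes)), defaultNames);
yticklabels(hAxes, desired_variable_names(loc)) % use "yticklabels" to change the YTickLabels
figure(11);
plot(explainer);
hAxes = gca;
hAxes.TickLabelInterpreter = "tex";
[~, loc] = ismember(string(yticklabels(hAxes)), defaultNames);
yticklabels(hAxes, desired_variable_names(loc)) % use "yticklabels" to change the YTickLabels
figure(12);
swarmchart(explainer);
hAxes = gca;
hAxes.TickLabelInterpreter = "tex";
[~, loc] = ismember(string(yticklabels(hAxes)), defaultNames);
yticklabels(hAxes, desired_variable_names(loc)) % use "yticklabels" to change the YTickLabels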
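With data stored as 7-by-5120 (one sample per column, the layout shallow networks use), the transposes in f are easy to get wrong, and a size mismatch there is one plausible source of errors from shapley. The sketch below checks the wrapper pattern f = @(x) net(x')' without needing a trained network, using a hypothetical stand-in function netStandIn in place of net; with your real model, substitute net for netStandIn.

```matlab
% Hypothetical stand-in for a trained shallow network: like net(x), it maps
% a 7-by-N matrix (one sample per column) to a 1-by-N row of predictions.
netStandIn = @(x) [1 2 3 4 5 6 7] * x;

xData = randn(7, 5120);    % same layout as the question's input
Xrows = xData';            % shapley wants one observation per row: 5120-by-7

% Same pattern as f in the answer: transpose in, transpose out,
% so an N-by-7 input gives an N-by-1 column of predictions.
f = @(X) netStandIn(X')';

yhat = f(Xrows(1:10, :));  % size(yhat) is [10 1]
```

If these sizes come out as expected, f and Xrows (or a subset of its rows) can be passed to shapley in the same way as in the answer.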