How to obtain Shapley Values with a pattern classification neural network?

조회 수: 3 (최근 30일)
We wanted to obtain the Shapley Values for a feature vector (query point) and a database trained with a neuralnet initiated with patternet.
As our neuralnet is a “pattern recognition neural network” We obtain this error message:
Error using shapley (line XX)
Blackbox model must be a classification model, regression model, or function handle
So, my question is: Is there any option to compute the Shapley values for a “pattern recognition neural network” model
Or instead, can We convert a “pattern recognition neural network” into a “classification neural network” in order to compute their Shappey values?
Thanks in any case.

채택된 답변

Amanjit Dulai
Amanjit Dulai 2022년 7월 18일
It is possible to do this by passing a function handle to shapley. This function handle needs to output the score for the class of interest. Also, shapley expects inputs and outputs for the function handle to be row vectors rather than column vectors, so some transposes are needed. Below is an example using the Fisher Iris data:
% Train a neural network on the iris data
[x,t] = iris_dataset;
net = patternnet(10);
net = train(net,x,t);
% Choose an observation to explain. We need its class as an index.
x1 = x(:,1);
t1 = find(t(:,1));
% Plot shapley values. For Setosa (the first class) the petal length (x3)
% is usually the most informative feature.
explainer = shapley( ...
@(x)predictScoreForSpecifiedClass(net,x,t1), ...
x', "QueryPoint", x1' );
plot(explainer)
% Helpers
function score = predictScoreForSpecifiedClass(net, x, classIndex)
Y = net(x');
score = Y(classIndex,:)';
end

추가 답변 (0개)

카테고리

Help CenterFile Exchange에서 Sequence and Numeric Feature Data Workflows에 대해 자세히 알아보기

제품


릴리스

R2021a

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by