
Shapley values for newff model

abouzar jafari on 10 Jun 2024
Commented: Drew on 12 Jun 2024
Hi there, I trained a BP neural network using the newff function and want to obtain its Shapley values and a swarmchart graph, but I get the following error:
Code:
P = [2 3 1;3 4 5;1 3 4;4 6 7;2 7 3]';
T = [1 2 3 4 5];
net=newff(P,T,5,{'tansig' 'purelin'},'trainlm');
[net,tr]=train(net,P,T);
queryPoint = P(:,1);
explainer1 = shapley(net,P,'QueryPoint',queryPoint);
Error:
Error using shapley (line 233)
Blackbox model must be a classification model, regression model, or function handle
Error in untitled (line 10)
explainer1 = shapley(net,P,'QueryPoint',queryPoint);

Accepted Answer

Drew on 10 Jun 2024
Edited: Drew on 12 Jun 2024
Answer to the initial question:
The short answer is that you can obtain the Shapley values of a model created by newff by using a function handle (more info below). Also consider whether the Classification Learner app and fitcnet would be more convenient, if the classification neural networks trained by fitcnet meet your use case.
A table of the model object types that are supported as the "blackbox model" argument to the shapley function can be seen in the shapley documentation at https://www.mathworks.com/help/stats/shapley.html#mw_c2327b12-104d-48ef-8a71-1f0e8769549b . Models created with newff are not on that list. An arbitrary model (such as the model created with newff) can be used with the shapley function by using a function handle for model prediction. That is mentioned in the shapley doc section referenced immediately above: "Function handle — You can specify a function handle that accepts predictor data and returns a column vector containing a prediction for each observation in the predictor data. The prediction is a predicted response for regression or a predicted score of a single class for classification. You must provide the predictor data using X." See the "Specify Blackbox Model Using Function Handle" section of the shapley documentation page for an example: https://www.mathworks.com/help/stats/shapley.html#mw_c8688099-08ba-41d5-8a6c-0785a609b341
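Applied to the code in the question, this means wrapping the trained network in a function handle. Note that the question's call also passes P (predictors in rows) and a column query point, whereas shapley expects observations in rows for both, so the sketch below transposes them as well. It assumes Deep Learning Toolbox (for newff and train) and Statistics and Machine Learning Toolbox (for shapley); it is a minimal sketch, not the only way to do this.

```matlab
% Toy data from the question: P is 3 predictors x 5 observations,
% T is 1 x 5 (one target value per observation).
P = [2 3 1;3 4 5;1 3 4;4 6 7;2 7 3]';
T = [1 2 3 4 5];

% Same network as in the question (newff is obsolete but still callable).
net = newff(P,T,5,{'tansig' 'purelin'},'trainlm');
net.trainParam.showWindow = false;   % suppress the training window
[net,tr] = train(net,P,T);

% shapley passes an n-by-3 matrix (observations in rows) and expects an
% n-by-1 column of predictions back, so transpose on the way in and out.
f = @(x) net(x')';

% Background data and query point must also have observations in rows.
queryPoint = P(:,1)';
explainer1 = shapley(f,P','QueryPoint',queryPoint);
disp(explainer1.ShapleyValues)
```

With only five observations the values are not meaningful; the point is only that shapley accepts the wrapped network.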
Other observations:
  • The function newff is quite old and was "Obsoleted in R2010b NNET 7.0". Consider using either feedforwardnet from Deep Learning Toolbox or fitcnet from Statistics and Machine Learning Toolbox (SMLT) instead. Models created with fitcnet can be used directly as the "blackbox" model with the shapley function, so if the neural networks available through fitcnet cover your use case, that could be the easiest path to the Shapley analysis. You can also use the Classification Learner app to build fitcnet models.
  • The example provided in the question has very little data, with only one training example for each target class. More observation data will be needed for meaningful training, validation, and testing.
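To illustrate the fitcnet path mentioned above, the sketch below passes a fitcnet model straight to shapley, with no function handle or transposes. The two-class data is hypothetical and synthetic, generated only so the example runs on its own; it assumes Statistics and Machine Learning Toolbox R2021a or later.

```matlab
% Hypothetical synthetic data: 200 observations in rows, 3 predictors.
% Class 1 when x1 + x2 > 1, else class 0; predictor 3 is irrelevant.
rng(0)
X = rand(200,3);
labels = double(X(:,1) + X(:,2) > 1);

% Train a small classification neural network (observations in rows).
mdl = fitcnet(X,labels,'LayerSizes',5);

% A full model stores its training predictors, so the background data X
% can be omitted; shapley then uses the predictors stored in mdl.
explainer = shapley(mdl,'QueryPoint',X(1,:));
disp(explainer.ShapleyValues)   % one row per predictor, one column per class
```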
Answer using the data provided in the first comment:
The data and model provided in the ANN1LogSig.mat file attached to the first comment show that this is a regression problem with 10 predictors. Using this data, we can (1) get the Shapley summary swarmchart for the net model provided, and (2) use the Regression Learner app to easily build many models on this data and do Shapley analysis.
(1) Get the Shapley summary swarmchart for the net model provided, using R2024a or higher.
After loading the data and model, the swarmchart can be created in just 3 lines of code.
% Load data and model. This is the .mat file attached to the first comment
% below.
data = load('ANN1LogSig.mat');
% Calculate shapley values for the model that is already built and loaded
% Define the function handle to specify the blackbox model.
% Shapley expects observations in rows, and net expects observations in
% columns, hence the two transpose operations.
f = @(x) data.net(x')';
% Create SHAP explainer. Use the training data as background samples.
% Use the test data as query points. One could alternatively use the
% training data, or all of the data, as query points.
explainer = shapley(f, data.trainInputs', queryPoints=data.testInputs');
% Shapley Summary
figure(1); swarmchart(explainer,ColorMap="BlueRed");
% Shapley Importance
figure(2); plot(explainer)
% Shapley local plot for one query point
figure(3);plot(explainer, QueryPointIndices=1);
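To see why f needs both transposes, here is a minimal check that needs no toolbox or .mat file. netStandIn is a hypothetical stand-in for data.net: like a shallow network, it takes a predictors-by-observations matrix and returns a 1-by-observations row of outputs.

```matlab
% Hypothetical stand-in for data.net: input is 10 x n (predictors x
% observations), output is 1 x n, matching the shallow-network convention.
netStandIn = @(x) sum(x.^2, 1);

% Same wrapping as f above: shapley supplies an n x 10 matrix and expects
% an n x 1 column of predictions, one per row.
f = @(x) netStandIn(x')';

X = rand(7, 10);     % 7 observations in rows, 10 predictors
yhat = f(X);         % 7-by-1 column; yhat(i) is the output for row i
```

Without the outer transpose, f would return a 1-by-7 row; without the inner one, the stand-in would treat the 7 observations as predictors.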
(2) Use Regression Learner app to easily build many models on this data, and do Shapley analysis.
% Prep data to send to Regression Learner
% For this, we will use the same test set (106) as in the mat file, but we
% will use train(499)+validation(106) as the training data (605), and then
% use 10-fold cross-validation in the Regression Learner app.
data = load('ANN1LogSig.mat');
% For all of the following, there are 11 columns, with predictors in
% columns 1-10, and the target in column 11.
allDataMatrix=data.ALL_DataNormal;
trainDataMatrix=allDataMatrix(1:605,:);
testDataMatrix=allDataMatrix(606:711,:);
% Start Regression Learner app
regressionLearner
In the session start dialog for Regression Learner, choose the trainDataMatrix variable and choose 10-fold cross-validation. Do not set aside a test set in the session start dialog. After the session starts, import testDataMatrix using the Test tab in the app.
Regression Learner makes it very easy to train many types of models. Use the "All" preset to try many model types, then try the various "Optimizable" presets to optimize hyperparameters for some model types. After training and testing many models (without writing any code), the "Compare Results" plot shows that the Gaussian Process Regression (GPR) models have the best cross-validated RMSE.
Model 7 has the lowest validation RMSE, so we choose it and check its performance on the test data to compare with the test RMSE of the net model in ANN1LogSig.mat. Model 7 has a test RMSE of 0.0337, less than half of the 0.0754 test RMSE of the net model in ANN1LogSig.mat:
>> data = load('ANN1LogSig.mat');
>> data.RMSE_ts
ans =
0.0754
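The same train, cross-validate, and test workflow can also be approximated at the command line. The sketch below uses fitrgp with default settings, which is not necessarily the kernel or hyperparameter configuration the app chose for Model 7. The matrices are hypothetical synthetic stand-ins with the same layout as trainDataMatrix and testDataMatrix (predictors in columns 1-10, target in column 11); with the real data you would skip the first block.

```matlab
% Hypothetical stand-ins with the same layout as trainDataMatrix (605 rows)
% and testDataMatrix (106 rows): predictors in columns 1-10, target in 11.
rng(1)
allData = rand(711, 10);
allData(:,11) = mean(allData(:,1:3), 2) + 0.01*randn(711, 1);
trainDataMatrix = allData(1:605, :);
testDataMatrix  = allData(606:711, :);

Xtrain = trainDataMatrix(:, 1:10);  ytrain = trainDataMatrix(:, 11);
Xtest  = testDataMatrix(:, 1:10);   ytest  = testDataMatrix(:, 11);

% 10-fold cross-validated GPR; kfoldLoss returns the mean squared error,
% so take the square root to get RMSE.
cvMdl = fitrgp(Xtrain, ytrain, 'KFold', 10);
cvRMSE = sqrt(kfoldLoss(cvMdl));

% Refit on all training data and evaluate once on the held-out test set.
gprMdl = fitrgp(Xtrain, ytrain);
testRMSE = sqrt(mean((predict(gprMdl, Xtest) - ytest).^2));
fprintf('CV RMSE: %.4f, test RMSE: %.4f\n', cvRMSE, testRMSE);
```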
Starting in R2024b, additional Shapley plots, including the Shapley Summary swarmchart, are available in the Classification Learner and Regression Learner apps. Users can try this in the R2024b prerelease, which is scheduled to be available in the second half of June 2024.
[Screenshot: Shapley Summary plot of the GPR model above in the Regression Learner app]
In the meantime, one can always export a model from the Classification or Regression Learner app and then use the command-line shapley and shapley.swarmchart functions to create the Shapley Summary plot, as illustrated above.
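A minimal sketch of that export path, assuming the model was exported under the default name trainedModel and that its predictFcn field accepts a matrix of predictors (observations in rows) when the session was started from a matrix. Here trainedModel is a hypothetical stand-in built with fitrgp on synthetic data so the code runs on its own; the QueryPoints argument and swarmchart on a shapley object require R2024a or later.

```matlab
% Hypothetical stand-in for a model exported from Regression Learner.
rng(2)
X = rand(300, 10);
y = X(:,1) - X(:,2) + 0.01*randn(300, 1);
gprMdl = fitrgp(X, y);
trainedModel.predictFcn = @(Xnew) predict(gprMdl, Xnew);

% predictFcn already takes observations in rows and returns a column of
% predictions, so it can be passed to shapley directly, without transposes.
f = trainedModel.predictFcn;
background = X(1:100, :);          % background (reference) samples
explainer = shapley(f, background, 'QueryPoints', X(201:220, :));

figure; swarmchart(explainer, 'ColorMap', 'BlueRed');
```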
If this answer helps you, please remember to accept the answer.
  4 Comments
abouzar jafari on 12 Jun 2024
@Drew, thank you very much! It works, but it's extremely slow. Any recommendations for making it faster?
Drew on 12 Jun 2024
The computation time scales linearly with the number of background observations. In R2024b, the background observations will be randomly subsampled to NumObservationsToSample observations (default 100). This gives an approximation of the Shapley values, but the accuracy obtained with 100 background samples is generally considered high enough for most purposes. For example, with 100 background samples instead of 499, the Shapley summary plot looks basically the same, but the computation is ~5x faster.
In R2024a, you can get the same speedup by subsampling the background (reference) samples before passing them to shapley:
data = load('ANN1LogSig.mat');
f = @(x) data.net(x')';
backgroundSamples = data.trainInputs(:,randsample(size(data.trainInputs,2),100))';
tstart=tic;
explainer = shapley(f, backgroundSamples, queryPoints=data.testInputs');
toc(tstart);
swarmchart(explainer,ColorMap="BlueRed");
With this change, on one of my machines, the Shapley values are calculated in about 40 seconds on R2024a.
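One caveat on the subsampling above: randsample(n,100) samples without replacement, so it raises an error if the reference set has fewer than 100 observations. A small guard, as a sketch assuming the same predictors-by-observations layout as data.trainInputs; trainInputs below is a hypothetical random stand-in, and randperm is used so the sketch runs in base MATLAB. Fixing the seed with rng makes the subsample, and therefore the approximate Shapley values, reproducible.

```matlab
% Hypothetical stand-in for data.trainInputs: 10 predictors x 499 observations.
rng(0)                                    % reproducible stand-in and subsample
trainInputs = rand(10, 499);

% Subsample at most 100 background observations (columns) without
% replacement, then transpose so observations are in rows for shapley.
n = size(trainInputs, 2);
k = min(100, n);
idx = randperm(n, k);
backgroundSamples = trainInputs(:, idx)';  % k x 10
```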
The calculation can be made faster still by using UseParallel=true, which takes advantage of Parallel Computing Toolbox (PCT). For this older type of neural network model, only PCT process pools are supported, so set the pool type to "Processes" in the "Parallel Computing Toolbox" section of the MATLAB preferences, then start the process pool. Starting the process pool can take more than a minute. With an 8-process pool on the same machine mentioned above, calculating the Shapley values takes about 7 seconds on R2024a:
>> tstart=tic;
>> explainer = shapley(f, backgroundSamples, queryPoints=data.testInputs', UseParallel=true);
>> toc(tstart);
Elapsed time is 6.873876 seconds.
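Instead of changing the preference, the process pool can also be started programmatically. A minimal sketch, assuming Parallel Computing Toolbox R2022b or later (where the "Processes" profile name is available); the pool size of 8 matches the timing above and should be adjusted to your core count, since requesting more workers than the profile allows raises an error.

```matlab
% Reuse an existing process pool; otherwise replace any other pool type
% (for example, a thread pool, which this network type does not support)
% with an 8-worker process pool.
pool = gcp('nocreate');
if isempty(pool) || ~isa(pool, 'parallel.ProcessPool')
    if ~isempty(pool)
        delete(pool);
    end
    pool = parpool('Processes', 8);
end
fprintf('Process pool with %d workers ready.\n', pool.NumWorkers);
```

After this, the shapley call with UseParallel=true shown above runs on the process pool.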


