Get Shapley values for the whole model

I'm trying to get the Shapley values for my classification model so that I can estimate the feature importances. However, based on the documentation, MATLAB's implementation only lets me compute the values for a single row (query point) of my dataset at a time. To get the values for the entire dataset and model, I have to loop over the rows like this:

explainer = shapley(mdl, data);
% Preallocate an array to store the Shapley values for the -1 and 1 classes
values = zeros(size(data, 2), 2, size(data, 1));
for row = 1:size(data, 1)
    temp_explainer = fit(explainer, data(row, :));
    svalues = temp_explainer.ShapleyValues;
    values(:, :, row) = table2array(svalues(:, 2:3));
end
% Average Shapley values across rows for each feature
feature_importances = mean(values, 3);

This is pretty slow. Is there a faster way to get the Shapley values for the entire model and dataset?

Accepted Answer
Drew
18 April 2023 (edited 18 April 2023)

For Shapley-based feature importance, it is recommended to use the mean absolute Shapley values, so insert an absolute value function in the last line: mean(abs(values), 3).

In general, you can reduce the Shapley computational cost through some combination of these methods:
- Use parallel processing over the query points with a "parfor" loop.
- Use parallel processing within the computation for one query point by setting 'UseParallel' to true in the call to 'fit'. If you have already parallelized over the query points with parfor, then also requesting parallelization within each query point may have little effect, and it can even make the computation slower when the parallelization overhead at this level outweighs the gains. See doc links below.
- Reduce the number of background samples. It is often sufficient to choose 100 random background samples from the predictor data. For example, your first line would then become something like "explainer = shapley(mdl, data(randsample(size(data,1),100), :));". See doc links below.
- Reduce the MaxNumSubsets parameter. If you have M predictors, the default MaxNumSubsets is min(2^M, 2^10). Reducing MaxNumSubsets makes the computation faster, at the expense of introducing some approximation into the calculation. See doc links below. Note: in R2023a, if you are using a tree model and the software is using the interventional-tree method, it is recommended not to set "MaxNumSubsets". It does not apply to interventional-tree, and specifying it causes the software to switch to the interventional-kernel method.
- R2023a includes Shapley computation improvements that can speed up the computation for linear and tree models. You can try R2023a at matlab.mathworks.com.
- Reduce the number of query points. If your dataset has a large number of observations, then computing the Shapley values for 1000 random query points may be sufficient for your purpose.
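As a rough sketch of how several of these cost-reduction methods can be combined, the MATLAB code below uses 100 random background samples, evaluates at most 1000 random query points, parallelizes over those query points with parfor, and reports mean absolute Shapley values. This is not code from the original post: the ionosphere sample data set, the fitctree model, and all variable names are illustrative assumptions. It also assumes a release in which fit accepts one query point at a time and the ShapleyValues table has a Predictor column followed by one column per class. Without Parallel Computing Toolbox, parfor runs as an ordinary serial loop.

```matlab
% Illustrative data and model (assumptions, not from the original post)
load ionosphere                 % X: 351-by-34 numeric predictors, Y: class labels 'b'/'g'
mdl = fitctree(X, Y);

rng(0)                          % make the random subsets reproducible
n = size(X, 1);

% Use 100 random background samples instead of the full predictor data
bgIdx = randsample(n, min(100, n));
explainer = shapley(mdl, X(bgIdx, :));

% Evaluate at most 1000 random query points (here n = 351, so all rows are used)
queryIdx = randsample(n, min(1000, n));
numQuery = numel(queryIdx);
numPred  = size(X, 2);
numClass = numel(mdl.ClassNames);

% Preallocate: predictors x classes x query points
values = zeros(numPred, numClass, numQuery);

% Parallelize over query points (serial loop if Parallel Computing Toolbox is absent)
parfor k = 1:numQuery
    e = fit(explainer, X(queryIdx(k), :));
    % Drop the Predictor column and keep one column of Shapley values per class
    values(:, :, k) = table2array(e.ShapleyValues(:, 2:end));
end

% Mean absolute Shapley value per predictor (rows) and class (columns)
featureImportance = mean(abs(values), 3);
```

Taking abs before averaging matters: positive and negative contributions of the same predictor would otherwise cancel across query points and understate its importance.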
Section on Shapley computational cost: https://www.mathworks.com/help/stats/shapley-values-for-machine-learning-model.html#mw_6fcd70ce-79f9-4075-a57b-af35dcebc171
Section on reducing Shapley computational cost: https://www.mathworks.com/help/stats/shapley-values-for-machine-learning-model.html#mw_d133b8fd-6c9e-4973-af4f-931781381497
MATLAB Answers post about creating Shapley summary plots: https://www.mathworks.com/matlabcentral/answers/1578665-how-can-i-get-a-shapley-summary-plot
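The 'MaxNumSubsets' and 'UseParallel' options are passed to fit as name-value arguments. The sketch below, again on the illustrative ionosphere data, computes the default MaxNumSubsets for M predictors and then calls fit with a smaller MaxNumSubsets and with UseParallel. Several choices here are assumptions: the Gaussian-kernel SVM is used so that the example is not a tree model (for which, per the answer, setting MaxNumSubsets in R2023a forces a switch away from interventional-tree), and the cap of 256 subsets is arbitrary. Whether it is accurate enough depends on your model.

```matlab
% Illustrative data and model (assumptions, not from the original post)
load ionosphere                         % X: 351-by-34 predictors, Y: 'b'/'g'
mdl = fitcsvm(X, Y, 'KernelFunction', 'gaussian', 'Standardize', true);

% Default number of predictor subsets for M predictors
M = size(X, 2);
defaultMaxNumSubsets = min(2^M, 2^10);  % 1024 for M = 34

rng(0)                                  % reproducible background sample
explainer = shapley(mdl, X(randsample(size(X, 1), 100), :));
queryPoint = X(1, :);

% Fewer predictor subsets: faster, but a coarser approximation
exFast = fit(explainer, queryPoint, 'MaxNumSubsets', 256);

% Parallelize the computation for this single query point
% (needs Parallel Computing Toolbox; may gain little or even slow down inside parfor)
exPar = fit(explainer, queryPoint, 'UseParallel', true);
```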