SHAP values for multiclass classification

Hi! I have a question regarding SHAP values. I trained a neural network using TensorFlow for a multiclass classification problem. My possible outcomes are 0, 1 and 2. Now I want to derive the variable importance using SHAP values. My model was trained on scaled (normalized) input features. Previously I treated the problem as a regression and had no problem calculating the SHAP values. When I apply the same approach to my multiclass classification model, I always get the error: TypeError: only integer scalar arrays can be converted to a scalar index.

The code I used is:
explainer = shap.Explainer(model.predict, X_train_scaled)
shap_values = explainer(X_test_scaled)
shap.summary_plot(shap_values)

Looking forward to any advice on how to adjust the code! Thanks and best regards!

import shap

# Assuming you have the scaled training data (X_train_scaled) and the test data (X_test_scaled)
# and a trained model, model

# Initialize the explainer. You need to use the model's output to get the class probabilities.
explainer = shap.KernelExplainer(model.predict, X_train_scaled)

# Generate SHAP values for the test data. shap_values will now be a list of values for each class.
shap_values = explainer.shap_values(X_test_scaled)

# shap_values is a list, where each entry corresponds to the SHAP values for a specific class.
# We can visualize the SHAP values for all classes or select a specific class.

# Visualize SHAP values for the first class (class 0):
shap.summary_plot(shap_values[0], X_test_scaled)

# You can similarly visualize for other classes by indexing shap_values[1], shap_values[2], etc.

Hi! Thank you for your reply! I tried your approach and now I get the following error: AssertionError: The shape of the shap_values matrix does not match the shape of the provided data matrix.

Any advice on why this happens? X_test_scaled and X_train_scaled are both NumPy arrays.
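One possible explanation, offered as an assumption rather than a confirmed diagnosis: depending on the installed shap version, explainer.shap_values may return either a list with one (n_samples, n_features) array per class, or a single array of shape (n_samples, n_features, n_classes). In the second case, shap_values[0] selects the first sample rather than the first class, so its shape no longer matches X_test_scaled. The sketch below uses random stand-in arrays instead of real SHAP output, and select_class is a hypothetical helper, not part of shap.

```python
import numpy as np


def select_class(shap_values, class_idx):
    """Return the (n_samples, n_features) SHAP matrix for one class.

    Hypothetical helper. Accepts either a list with one array per class
    or a single 3D array laid out as (n_samples, n_features, n_classes).
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[class_idx])
    arr = np.asarray(shap_values)
    if arr.ndim == 3:
        return arr[:, :, class_idx]
    raise ValueError(f"unexpected SHAP value shape: {arr.shape}")


# Stand-in data (assumed sizes): 5 samples, 4 features, 3 classes.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(5, 4))
sv_array = rng.normal(size=(5, 4, 3))            # single 3D array layout
sv_list = [sv_array[:, :, k] for k in range(3)]  # list-per-class layout

class0_from_array = select_class(sv_array, 0)
class0_from_list = select_class(sv_list, 0)

# Plain [0] on the 3D layout picks one sample, giving (n_features, n_classes).
print(sv_array[0].shape)
# The helper returns a matrix with the same shape as the data in both layouts.
print(class0_from_array.shape, class0_from_list.shape, X_demo.shape)
```

Printing np.shape(shap_values) on the real output is a quick way to see which layout you have. If it turns out to be the 3D one, passing select_class(shap_values, 0) together with X_test_scaled to shap.summary_plot should give the plot two matrices of the same shape.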

If you’re going to post code on the forum, please use the “preformatted text” tag. Otherwise your code is rendered as markdown.

Hi @chippa

You are getting this error because the shape and object type of the SHAP values do not match those of your data. So make sure all of your data is of the same type.

To resolve the assertion error, please refer to the link below for the corrections: link

As for the earlier TypeError: it is raised when X_train is a Python list and you index that list with a NumPy array.

So first convert your data to a NumPy array, with something like:

np.array(X_train)[indices.astype(int)]
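As a small illustration of that point, the sketch below indexes a plain Python list with a NumPy array, which fails, and then repeats the lookup after converting the list to an array. The values in X_train and indices are made-up stand-ins, not the data from this thread.

```python
import numpy as np

# Hypothetical stand-in data: a plain Python list of rows.
X_train = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
# An index array that happens to be stored as floats.
indices = np.array([0.0, 2.0])

# Indexing a Python list with a multi-element NumPy array raises a TypeError.
try:
    X_train[indices.astype(int)]
    list_indexing_failed = False
except TypeError as err:
    list_indexing_failed = True
    print("list indexing failed:", err)

# Converting to a NumPy array first makes integer-array indexing valid.
subset = np.array(X_train)[indices.astype(int)]
print(subset.shape)
```

Checking type(X_train) before building the explainer shows whether this conversion is needed at all.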

Hope this helps! Your main issue is that your data is not all of the same data type (e.g. int, float, complex).

Regards
DP