How to get feature names of shap_values from TreeExplainer?

I'm working through a SHAP tutorial and trying to get the SHAP values for each person in the dataset:

from sklearn.model_selection import train_test_split
import xgboost
import shap
import numpy as np
import pandas as pd
import matplotlib.pylab as pl

X,y = shap.datasets.adult()
X_display,y_display = shap.datasets.adult(display=True)

# create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
d_train = xgboost.DMatrix(X_train, label=y_train)
d_test = xgboost.DMatrix(X_test, label=y_test)
params = {
    "eta": 0.01,
    "objective": "binary:logistic",
    "subsample": 0.5,
    "base_score": np.mean(y_train),
    "eval_metric": "logloss"
}
#model = xgboost.train(params, d_train, 5000, evals = [(d_test, "test")], verbose_eval=100, early_stopping_rounds=20)

xg_clf = xgboost.XGBClassifier()
xg_clf.fit(X_train, y_train)
explainer = shap.TreeExplainer(xg_clf, X_train)
#shap_values = explainer(X)
shap_values = explainer.shap_values(X)

In the Python 3 interpreter, shap_values is a large array covering 32,561 people, each with SHAP values for 12 features.
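That layout (one row per person, one column per feature) can be sketched with a made-up stand-in array. The random values and the five-row size below are placeholders, not real SHAP output, and the sketch assumes the columns follow the order of the feature-name list shown further down (which is exactly what this question is asking about).

```python
import numpy as np

feature_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation',
                 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss',
                 'Hours per week', 'Country']

# Hypothetical stand-in for explainer.shap_values(X): same layout,
# one row per person and one column per feature, filled with random numbers.
rng = np.random.default_rng(0)
shap_values = rng.normal(size=(5, len(feature_names)))

# A row is one person: one SHAP value per feature.
first_person = shap_values[0]

# A column is one feature across all people, looked up by its position in the name list.
age_col = shap_values[:, feature_names.index('Age')]

print(shap_values.shape, first_person.shape, age_col.shape)
```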

For example, the SHAP values for the first person are:

>>> shap_values[0]
array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,
       -0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,
       -0.26587735,  0.02700199])

However, which value belongs to which feature is a complete mystery to me.

The documentation says:

For models with a single output this returns a matrix of SHAP values
        (# samples x # features). Each row sums to the difference between the model output for that
        sample and the expected value of the model output (which is stored in the expected_value
        attribute of the explainer when it is constant). For models with vector outputs this returns
        a list of such matrices, one for each output
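The additivity property described in that quote can be checked by hand on a toy example. The sketch below computes exact Shapley values by brute-force enumeration of coalitions for a hypothetical three-feature function `model`, using a single all-zeros baseline point as the reference, so the "expected value" is simply `model(baseline)`. This is an illustration of the definition only; TreeExplainer uses the Tree SHAP algorithm and a background dataset, not this enumeration. It shows that entry j of the result is the attribution for input feature j, and that the entries sum to `model(x) - model(baseline)`.

```python
from itertools import combinations
from math import factorial


def model(x):
    # Hypothetical 3-feature model with an interaction between features 1 and 2.
    return 2.0 * x[0] + x[1] * x[2] - 0.5 * x[2]


def exact_shapley(f, x, baseline):
    """Brute-force Shapley values of f at x.

    Features outside a coalition are replaced by their value in a single
    baseline point, so the reference ("expected") output is f(baseline).
    """
    n = len(x)

    def value(coalition):
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = []
    for j in range(n):
        others = [i for i in range(n) if i != j]
        total = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {j}) - value(s))
        phi.append(total)
    return phi


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(model, x, baseline)

# phi[j] is the attribution of input feature j, in the same order as x.
print(phi)
# Efficiency: the attributions sum to f(x) - f(baseline).
print(sum(phi), model(x) - model(baseline))
```

Here the additive term `2.0 * x[0]` is credited entirely to feature 0, the interaction `x[1] * x[2]` is split evenly between features 1 and 2, and `-0.5 * x[2]` goes to feature 2.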

When I look at the explainer that produced shap_values, I can see that it exposes the feature names:

explainer.data_feature_names
['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

But I can't see how to get the feature names from shap_values itself in the Python interpreter, if they are even there:

>>> shap_values.
shap_values.all(           shap_values.compress(      shap_values.dump(          shap_values.max(           shap_values.ravel(         shap_values.sort(          shap_values.tostring(
shap_values.any(           shap_values.conj(          shap_values.dumps(         shap_values.mean(          shap_values.real           shap_values.squeeze(       shap_values.trace(
shap_values.argmax(        shap_values.conjugate(     shap_values.fill(          shap_values.min(           shap_values.repeat(        shap_values.std(           shap_values.transpose(
shap_values.argmin(        shap_values.copy(          shap_values.flags          shap_values.nbytes         shap_values.reshape(       shap_values.strides        shap_values.var(
shap_values.argpartition(  shap_values.ctypes         shap_values.flat           shap_values.ndim           shap_values.resize(        shap_values.sum(           shap_values.view(
shap_values.argsort(       shap_values.cumprod(       shap_values.flatten(       shap_values.newbyteorder(  shap_values.round(         shap_values.swapaxes(      
shap_values.astype(        shap_values.cumsum(        shap_values.getfield(      shap_values.nonzero(       shap_values.searchsorted(  shap_values.T              
shap_values.base           shap_values.data           shap_values.imag           shap_values.partition(     shap_values.setfield(      shap_values.take(          
shap_values.byteswap(      shap_values.diagonal(      shap_values.item(          shap_values.prod(          shap_values.setflags(      shap_values.tobytes(       
shap_values.choose(        shap_values.dot(           shap_values.itemset(       shap_values.ptp(           shap_values.shape          shap_values.tofile(        
shap_values.clip(          shap_values.dtype          shap_values.itemsize       shap_values.put(           shap_values.size           shap_values.tolist(    

My main question: how can I tell which feature in

['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

corresponds to which number in each row of shap_values?

>>> shap_values[0]
array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,
       -0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,
       -0.26587735,  0.02700199])

I assume the features are in the same order, but I have no evidence of that.

My second question: how do I find the feature names in shap_values?

The features are indeed in the same order you assumed. See the GitHub issues "how to extract the most important feature names?" and "how to get feature names from explainer".

To find a feature's name, just take the element of the names list at the same index as the value in the SHAP row.

For example:

import numpy as np

# SHAP values of the first person (shap_values[0] from the question)
row = np.array([
    0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,
   -0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,
   -0.26587735,  0.02700199])
features_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation',
                  'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss',
                  'Hours per week', 'Country']

# argmin/argmax on a single 1-D row give a column index into features_names
features_names[row.argmin()]  # index 8 -> 'Capital Gain'
features_names[row.argmax()]  # index 0 -> 'Age'
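The same index-based pairing extends to the whole row. A minimal sketch, using the first person's values from the question and relying on the (confirmed above) assumption that the columns follow the order of the feature-name list: zip the names with the values, then rank them by absolute contribution. With the real objects you would zip `explainer.data_feature_names` with `shap_values[i]` instead of the hard-coded lists.

```python
feature_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation',
                 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss',
                 'Hours per week', 'Country']

# SHAP values of the first person, copied from shap_values[0] in the question.
row = [0.76437867, -0.11881508, 0.57451954, -0.41974955, -0.20982443,
       -0.38079952, -0.00986504, 0.32272505, -3.04392116, 0.00411322,
       -0.26587735, 0.02700199]

# Pair each name with the value at the same index.
labeled = dict(zip(feature_names, row))

# Rank features by the magnitude of their contribution for this person.
ranked = sorted(labeled.items(), key=lambda kv: abs(kv[1]), reverse=True)
for name, value in ranked[:3]:
    print(f"{name:>15}: {value:+.4f}")
```

Since pandas is already imported in the question, `pd.DataFrame(shap_values, columns=X.columns)` gives the whole matrix as a table labeled by feature name, under the same column-order assumption.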