如何使用 matplotlib 转换 nd 数组以绘制特征重要性

How to transform nd array to plot feature importances using matplotlib

我正在使用 shap 来确定一个 MLP 分类器(MLPClassifier)的特征重要性。

(注意:我在这里使用的是示例数据和示例模型,因为我的实际数据和模型是专有的。)

import shap, pandas as pd, numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Load the iris dataset. NOTE: X and y are assigned here but are not
# referenced again anywhere in this snippet.
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Build one DataFrame holding the feature columns plus a trailing 'target'
# column, then split the target back out as the label.
data = pd.DataFrame(data= np.c_[iris['data'], iris['target']],
                     columns= iris['feature_names'] + ['target'])
label = data['target']
data.drop('target', axis=1, inplace=True)
# random_state is itself drawn at random (an int in [1, 10)), so the
# train/test split differs between runs.
X_train, X_test, y_train, y_test = train_test_split(data, label,random_state=np.random.randint(1,10), test_size=0.3)

# Fit the classifier; the score below is computed but not stored.
mlp = MLPClassifier(max_iter=150).fit(X_train, y_train)                                            
mlp.score(X_test, y_test)

# Explain predict_proba, using a 5-cluster k-means summary of the training
# data as the background set.
explainer = shap.KernelExplainer(mlp.predict_proba, shap.kmeans(X_train, 5))
# NOTE(review): the code below indexes shap_values[1], i.e. it presumably
# expects one array of SHAP values per class -- confirm for the installed
# shap version.
shap_values = explainer.shap_values(X_test)

# First plot
shap.summary_plot(shap_values[1], feature_names = X_test.columns, plot_type='bar')

# Second, error, empty plot
import matplotlib.pyplot as plt; plt.rcdefaults()
y_pos = np.arange(len(X_test.columns))
# This is the call that raises "TypeError: only size-1 arrays can be
# converted to Python scalars" in the traceback below. y_pos has one entry
# per feature, while shap_values[1] is passed un-aggregated (presumably a
# samples x features array), so there is no single bar height per feature.
plt.bar(y_pos, shap_values[1], align='center', alpha=0.5)
plt.xticks(y_pos, X_test.columns)
plt.ylabel('SHAP Importance')
plt.title('MLP Feature Importances')

plt.show()

按照这篇指南(this guide)使用 shap.summary_plot,我得到了如下的图:

在我的实际数据集中,我有大约 10,000 个特征。我想要做的是从该列表中取出最重要的 n 个特征,并按照另一篇指南(this guide)用 matplotlib 绘制它们。但是,我得到了一个错误和一张空白图表:

完整追溯:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-42-ba03083152ca> in <module>
      1 import matplotlib.pyplot as plt; plt.rcdefaults()
      2 y_pos = np.arange(len(X_test.columns))
----> 3 plt.bar(y_pos, shap_values[1], align='center', alpha=0.5)
      4 plt.xticks(y_pos, X_test.columns)
      5 plt.ylabel('SHAP Importance')

c:\python367-64\lib\site-packages\matplotlib\pyplot.py in bar(x, height, width, bottom, align, data, **kwargs)
   2407     return gca().bar(
   2408         x, height, width=width, bottom=bottom, align=align,
-> 2409         **({"data": data} if data is not None else {}), **kwargs)
   2410 
   2411 

c:\python367-64\lib\site-packages\matplotlib\__init__.py in inner(ax, data, *args, **kwargs)
   1563     def inner(ax, *args, data=None, **kwargs):
   1564         if data is None:
-> 1565             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1566 
   1567         bound = new_sig.bind(ax, *args, **kwargs)

c:\python367-64\lib\site-packages\matplotlib\axes\_axes.py in bar(self, x, height, width, bottom, align, **kwargs)
   2393                 edgecolor=e,
   2394                 linewidth=lw,
-> 2395                 label='_nolegend_',
   2396                 )
   2397             r.update(kwargs)

c:\python367-64\lib\site-packages\matplotlib\patches.py in __init__(self, xy, width, height, angle, **kwargs)
    725         """
    726 
--> 727         Patch.__init__(self, **kwargs)
    728 
    729         self._x0 = xy[0]

c:\python367-64\lib\site-packages\matplotlib\patches.py in __init__(self, edgecolor, facecolor, color, linewidth, linestyle, antialiased, hatch, fill, capstyle, joinstyle, **kwargs)
     87         self.set_fill(fill)
     88         self.set_linestyle(linestyle)
---> 89         self.set_linewidth(linewidth)
     90         self.set_antialiased(antialiased)
     91         self.set_hatch(hatch)

c:\python367-64\lib\site-packages\matplotlib\patches.py in set_linewidth(self, w)
    393                 w = mpl.rcParams['axes.linewidth']
    394 
--> 395         self._linewidth = float(w)
    396         # scale the dash pattern by the linewidth
    397         offset, ls = self._us_dashes

TypeError: only size-1 arrays can be converted to Python scalars

我无法将 shap_values[1] 中的 SHAP 值对应到 X_test.columns 中的具体特征。我怎样才能正确地取出最重要的前 n 个特征,并用 matplotlib 绘制出来?

shap_values[1] 保存的是类别 1 的逐样本 SHAP 值(每行对应一个样本,每列对应一个特征);沿样本维度对每个特征进行聚合,即可得到特征重要性。所以你可以把它变成一个 pandas Series:

# Number of most-important features to plot.
top_n = 5

# Feature importances, sorted from most to least important.
# shap_values[1] holds signed SHAP values, so take the absolute value
# BEFORE averaging over axis 0: mean(|SHAP|) is the quantity drawn by
# shap.summary_plot(..., plot_type='bar'). Averaging first and taking
# abs() afterwards lets positive and negative contributions cancel out
# and can under-rank features that matter.
fi = (pd.Series(np.abs(shap_values[1]).mean(axis=0), index=X_test.columns)
        .sort_values(ascending=False)
     )

# extract the top_n values and plot them as horizontal bars
fi.head(top_n).plot.barh()

输出: