无法使用 SHAP 显示条形图

Can't display bar plot with SHAP

我是 SHAP 的新手,正在尝试在我的 RandomForestClassifier 之上使用它。这是我已经 运行 clf.fit(train_x, train_y):

之后的代码片段
explainer = shap.Explainer(clf)
shap_values = explainer(train_x.to_numpy()[0:5, :])
shap.summary_plot(shap_values, plot_type='bar')

这是结果图:

现在,这有两个问题。一是它不是条形图,即使我设置了 plot_type 参数。另一个是我似乎以某种方式丢失了我的特征名称(是的,当调用 clf.fit() 时它们确实存在于数据帧中)。

我尝试将最后一行替换为:

shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], plot_type='bar')

这并没有改变什么。我还尝试用以下内容替换它,看看我是否至少可以恢复我的功能名称:

shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], feature_names=list(train_x.columns.values), plot_type='bar')

但这引发了一个错误:

Traceback (most recent call last):
  File "sklearn_model_runs.py", line 41, in <module>
    main()
  File "sklearn_model_runs.py", line 38, in main
    shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], feature_names=list(train_x.columns.values), plot_type='bar')
  File "C:\Users\kapoo\anaconda3\envs\sci\lib\site-packages\shap\plots\_beeswarm.py", line 554, in summary_legacy
    feature_names=feature_names[sort_inds],
TypeError: only integer scalar arrays can be converted to a scalar index

此时我有点不知所措。我刚刚尝试了 5 行训练集,但一旦我越过了这个绊脚石,我就想使用整个训练集。如果有帮助,分类器有 5 个标签,我的 SHAP 版本是 0.40.0。

好吧,问题来了。替换为:

shap_values = explainer(train_x.to_numpy()[0:5, :])

有了这个:

shap_values = explainer.shap_values(train_x) # Use whole thing as dataframe

然后你可以在绘图过程中使用它:

feature_names=list(train_x.columns.values)

文档 here 确实应该更新...