为 tidymodel 对象创建 SHAP 图

Create SHAP plots for tidymodel objects

本题参考Obtaining summary shap plot for catboost model with tidymodels in R。根据问题下方的评论,OP 找到了解决方案,但目前尚未与社区分享。

我想用 tidymodels 包分析我的树集合,其中包含 SHAP 值图,例如像

这样的单一观察图

并总结我的数据集所有特征的效果,例如

DALEXtra 提供了为 tidymodels explain.tidymodels() 创建 SHAP 值的功能。 fastshap 包中的 force_plot 为底层 python 包 SHAP 的绘图函数提供了包装器。但是我不明白如何使函数与 explain.tidymodels() 函数的输出一起工作。

问题:如何使用 tidymodelsexplain.tidymodels 在 R 中生成这样的 SHAP 图?

MWE(对于 explain.tidymodels 的 SHAP 值)

library(MASS)
library(tidyverse)
library(tidymodels)
library(parsnip)
library(treesnip)
library(catboost)
library(fastshap)
library(DALEXtra)
set.seed(1337)
rec <-  recipe(crim ~ ., data = Boston)

split <- initial_split(Boston)

train_data <- training(split)

test_data <- testing(split) %>% dplyr::select(-crim) %>% as.matrix()

model_default<-
  parsnip::boost_tree(
    mode = "regression"
  ) %>%
  set_engine(engine = 'catboost', loss_function = 'RMSE')
#sometimes catboost is not loaded correctly the following two lines
#ensure prevent fitting errors
#https://github.com/curso-r/treesnip/issues/21 error is mentioned on last post
set_dependency("boost_tree", eng = "catboost", "catboost")
set_dependency("boost_tree", eng = "catboost", "treesnip")

model_fit_wf <- model_fit_wf <- workflow() %>% add_model(model_tune) %>%  add_recipe(rec) %>% {parsnip::fit(object = ., data =  train_data)}

SHAP_wf <- explain_tidymodels(model_fit_wf, data = X, y = train_data$crim, new_data = test_data

也许这会有所帮助。至少,这是朝着正确方向迈出的一步。

首先,确保安装了 fastshap 和 reticulate(即 install.packages("..."))。接下来,设置虚拟环境并安装shap(pip install ...)。此外,为依赖关系图安装 matplotlib 3.2.2(查看 GitHub issues on this —— 需要旧版本的 matplotlib)。

RStudio 提供了有关虚拟环境设置的大量信息。也就是说,虚拟环境设置需要或多或少的故障排除,具体取决于 IDE 的使用情况。 (遗憾的是,由于许可,一些工作设置限制了开源 RStudio 的使用。)

库 (fastshap) 的文档在这方面也很有帮助。

这是 lightgbm 的工作流程(来自 treesnip 文档,稍作修改)。

library(tidymodels)
library(treesnip)

data("diamonds", package = "ggplot2")
diamonds <- diamonds %>% sample_n(1000)

# vfold resamples
diamonds_splits <- vfold_cv(diamonds, v = 5)

model_spec <- boost_tree(mtry = 5, trees = 500) %>% set_mode("regression")

# model specs
lightgbm_model <- model_spec %>% 
    set_engine("lightgbm", nthread = 6)

#workflows
lightgbm_wf <- workflow() %>% 
    add_model(
       lightgbm_model
    )

rec_ordered <- recipe(
    price ~ .
      , data = diamonds
) 

lightgbm_fit_ordered <- fit_resamples(
  add_recipe(
    lightgbm_wf, rec_ordered
    ), resamples = diamonds_splits)

在预测之前,我们想要适应我们的工作流程

fit_workflow <- lightgbm_wf %>% 
     add_recipe(rec_ordered) %>% 
     fit(data = diamonds)

现在我们有了合适的工作流程并且可以预测。要使用 fastshap::explain 函数,我们需要创建一个预测函数(这并不总是成立:取决于所使用的引擎,它可能会或可能不会开箱即用 - 请参阅文档)。

predict_function_gbm <-  function(model, newdata) {
    predict(model, newdata) %>% pluck(.,1)
}

让我们得到平均预测值(在下面使用)。这也可作为检查以确保功能正常运行。

mean_preds <- mean(
    predict_function_gbm(
       fit_workflow, diamonds %>% select(-price)
   )
)

现在我们创建我们的解释(形状值)。请注意此处的 pred_wrapper 和 X 参数(有关其他示例,请参阅 fastshap github 问题——即 glmnet)。

fastshap::explain( 
    fit_workflow, 
    X = as.data.frame(diamonds %>% select(-price)),
    pred_wrapper = predict_function_gbm, 
    nsim = 10
) -> explanations_gbm

这应该会产生力图。

fastshap::force_plot(
    object = explanations_gbm[1,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1,], 
    display = "viewer", 
    baseline = mean_preds) 

这允许多个垂直堆叠:

fastshap::force_plot(
    object = explanations_gbm[1:20,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1:20,], 
    display = "viewer", 
    baseline = mean_preds) 

添加link = "logit" 进行分类。将 Rmarkdown 渲染的显示更改为“html”。

现在用于汇总图和依赖图。

诀窍是使用网状结构直接访问函数。请注意,同样的逻辑适用于 transformers、numpy 等库

首先,对于依赖图。

library(reticulate)
shap = import("shap")
np = import("numpy") 

shap$dependence_plot(
     "rank(3)", 
     data.matrix(explanations_gbm),
     data.matrix(diamond %>% select(-price))
)

有关 rank(3) 的解释,请参阅 shap 文档 -- rank(1) 等也可以。

不幸的是,当我尝试直接命名该功能(即“cut”)时它抛出了一个错误。

现在是摘要情节:

shap$summary_plot( 
    data.matrix(explanations_gbm),
    data.matrix(diamond %>% select(-price))
)

最后说明:重复渲染绘图会产生错误的可视化效果。希望这为 catboost 可视化提供了一个出发点。