删除 model_parts() 图中的变量

Remove variables from a model_parts() plot

我想从 model_parts() 生成的变量重要性图中删除某些变量。

# Packages
library(tidymodels)
library(mlbench)
library(DALEXtra)

# Data
data("PimaIndiansDiabetes")
dat <- PimaIndiansDiabetes
# Rows 1-384 -> "group 1", the rest -> "group 2", built in one vectorised
# step. The two-step `dat$col[1:384] <- ...` form only worked because
# data.frame `$<-` silently recycles a 384-long vector to 768 rows.
dat$some_new_group <- ifelse(seq_len(nrow(dat)) <= 384L, "group 1", "group 2")

# Split
set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)
dat_test <- testing(ind)

# Recipe: drop `pressure` and give `some_new_group` a non-predictor role
svm_rec <-
  recipe(diabetes ~ ., data = dat_train) %>%
  update_role(some_new_group, new_role = "group_var") %>%
  step_rm(pressure) %>%
  step_YeoJohnson(all_numeric_predictors())

# Model spec
svm_spec <-
  svm_rbf() %>%
  set_mode("classification") %>%
  set_engine("kernlab")

# Workflow
svm_wf <-
  workflow() %>%
  add_recipe(svm_rec) %>%
  add_model(svm_spec)

# Train
svm_trained <-
  svm_wf %>%
  fit(dat_train)

# Explainer
# DALEX's binary-classification losses expect `y` coded 0/1, with 1 for the
# class whose probability is predicted (assumed "pos", the second level of
# `diabetes`). as.numeric() on the factor would yield 1/2 instead.
svm_exp <- explain_tidymodels(
  svm_trained,
  data = dat %>% select(-diabetes),
  y = as.integer(dat$diabetes == "pos"),
  label = "SVM"
)

# Variable importance
# model_parts() defaults to B = 10 permutation rounds; request 50 so the
# plot title below is accurate.
set.seed(123)
svm_vp <- model_parts(svm_exp, type = "variable_importance", B = 50)
svm_vp

plot(svm_vp) +
  ggtitle("Mean-variable importance over 50 permutations", "")

请注意，在上面的 recipe 中，我删除了变量 pressure，并把新建的分类变量 some_new_group 的角色设为 "group_var"（因此它不是预测变量）。

所以，我可以像这样从图中手动删除变量 pressure 和 some_new_group：

# `!=` against a length-2 vector recycles it row by row, so each row would be
# compared with only one of the two names; `%in%` tests set membership.
plot(svm_vp %>% filter(!variable %in% c("pressure", "some_new_group"))) +
  ggtitle("Mean-variable importance over 50 permutations", "")

但是，能否在运行 explain_tidymodels() 或 model_parts() 时就直接排除这些变量？

如果某些变量既不是 workflow() 所处理的预测变量，也不是结果变量（例如您删除的变量和分组变量），就需要确保只把结果变量和预测变量传给 explain_tidymodels()。此外，还需要基于 parsnip 模型（而不是 workflow()）来构建解释器，因为 workflow() 会期望输入中包含这些非结果、非预测变量。

library(tidymodels)

# Data
data("PimaIndiansDiabetes", package = "mlbench")
dat <- PimaIndiansDiabetes
# Rows 1-384 -> "group 1", the rest -> "group 2", built in one vectorised
# step instead of relying on data.frame `$<-` recycling a partial vector.
dat$some_new_group <- ifelse(seq_len(nrow(dat)) <= 384L, "group 1", "group 2")

# Split
set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)
dat_test <- testing(ind)

# Recipe: drop `pressure` and give `some_new_group` a non-predictor role
svm_rec <-
  recipe(diabetes ~ ., data = dat_train) %>%
  update_role(some_new_group, new_role = "group_var") %>%
  step_rm(pressure) %>%
  step_YeoJohnson(all_numeric_predictors())

# Model spec
svm_spec <-
  svm_rbf() %>%
  set_mode("classification") %>%
  set_engine("kernlab")

# Train
svm_trained <-
  workflow(svm_rec, svm_spec) %>%
  fit(dat_train)

# Explainer
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.0).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> 
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain

# Explain the fitted parsnip model rather than the workflow, so DALEX only
# ever sees the preprocessed predictors. prep() estimates the recipe on its
# template data (dat_train) and bake(new_data = NULL, all_predictors())
# returns only the processed predictor columns, in dat_train's row order.
# `y` is coded 0/1 (1 = "pos", the class assumed to be predicted); calling
# as.numeric() on the factor would give 1/2.
svm_exp <- explain_tidymodels(
  extract_fit_parsnip(svm_trained),
  data = svm_rec %>% prep() %>% bake(new_data = NULL, all_predictors()),
  y = as.integer(dat_train$diabetes == "pos"),
  label = "SVM"
)
# NOTE: the output below was captured with the original 1/2 coding of `y`
# (hence residuals between ~0.1 and ~1.9); values differ with 0/1 coding.
#> Preparation of a new explainer is initiated
#>   -> model label       :  SVM 
#>   -> data              :  576  rows  7  cols 
#>   -> data              :  tibble converted into a data.frame 
#>   -> target variable   :  576  values 
#>   -> predict function  :  yhat.model_fit  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package parsnip , ver. 0.2.1 , task classification (  default  ) 
#>   -> predicted values  :  numerical, min =  0.08057345 , mean =  0.3540662 , max =  0.9357536  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  0.1083522 , mean =  0.9948921 , max =  1.895405  
#>   A new explainer has been created!

# Variable importance
# model_parts() defaults to B = 10 permutation rounds; request 50 so the
# plot title below is accurate.
set.seed(123)
svm_vp <- model_parts(svm_exp, type = "variable_importance", B = 50)
svm_vp
# NOTE: captured with the original 1/2 `y` coding and B = 10; the dropout
# losses will differ (the full model should then score below the baseline).
#>       variable mean_dropout_loss label
#> 1 _full_model_         0.6861190   SVM
#> 2      glucose         0.5919956   SVM
#> 3         mass         0.6673947   SVM
#> 4     pregnant         0.6700007   SVM
#> 5          age         0.6701185   SVM
#> 6     pedigree         0.6702812   SVM
#> 7      triceps         0.6760106   SVM
#> 8      insulin         0.6777355   SVM
#> 9   _baseline_         0.5020752   SVM

plot(svm_vp) +
  ggtitle("Mean-variable importance over 50 permutations", "")

由 reprex 包 (v2.0.1) 于 2022-05-03 创建

如果您的 workflow() 中包含这些不应参与解释的“额外”变量，就需要像上面这样做一些额外的工作，而不能仅依赖 workflow()。