删除 model_parts() 图中的变量
Remove variable in model_parts() plot
我想从图中删除某些变量。
# Reproducible example: fit an RBF SVM to the Pima diabetes data with a
# tidymodels workflow, then plot DALEX permutation variable importance.

# Packages ----
library(tidymodels)
library(mlbench)

# Data ----
data("PimaIndiansDiabetes")
dat <- PimaIndiansDiabetes
# Build the grouping column in a single assignment: the first 384 rows are
# "group 1", the remaining 384 rows are "group 2". Assigning into a slice of a
# column that does not exist yet only works when the slice length happens to
# divide nrow(dat).
dat$some_new_group <- rep(c("group 1", "group 2"), each = 384)

# Split ----
set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)
dat_test <- testing(ind)

# Recipe ----
# some_new_group is kept in the data but given a non-predictor role;
# pressure is removed before modelling.
svm_rec <-
  recipe(diabetes ~ ., data = dat_train) %>%
  update_role(some_new_group, new_role = "group_var") %>%
  step_rm(pressure) %>%
  step_YeoJohnson(all_numeric_predictors())

# Model spec ----
svm_spec <-
  svm_rbf() %>%
  set_mode("classification") %>%
  set_engine("kernlab")

# Workflow ----
svm_wf <-
  workflow() %>%
  add_recipe(svm_rec) %>%
  add_model(svm_spec)

# Train ----
svm_trained <- fit(svm_wf, dat_train)

# Explainer ----
library(DALEXtra)
svm_exp <- explain_tidymodels(
  svm_trained,
  data = dat %>% select(-diabetes),
  # as.numeric() on a factor returns the level codes 1/2; subtract 1 so the
  # target is 0/1 and comparable with the predicted class probability.
  y = as.numeric(dat$diabetes) - 1,
  label = "SVM"
)

# Variable importance ----
set.seed(123)
# B = 50 so that the plot title below ("50 permutations") is accurate;
# without it model_parts() uses its smaller default number of permutations.
svm_vp <- model_parts(svm_exp, type = "variable_importance", B = 50)
svm_vp
plot(svm_vp) +
  ggtitle("Mean-variable importance over 50 permutations", "")
请注意,在上面的 recipe 中,我用 step_rm() 删除了变量 pressure;另外数据中新增了一个分类变量 some_new_group(在 recipe 中被设为 "group_var" 角色,不作为预测变量)。
所以,我可以像这样从图中手动删除变量 pressure 和 some_new_group:
# Drop the variables the model never used before plotting. `%in%` tests every
# row against both names; `!=` against a length-2 vector is recycled
# element-wise and would leave some of those rows in place.
plot(svm_vp %>% filter(!variable %in% c("pressure", "some_new_group"))) +
  ggtitle("Mean-variable importance over 50 permutations", "")
但是,能否在运行 explain_tidymodels() 或 model_parts() 时就把这些变量排除掉?
如果数据中存在既不是预测变量也不是结果变量、而是交由 workflow() 处理的变量(例如被您删除的变量和您的分组变量),您需要确保只把结果变量和预测变量传给 explain_tidymodels()。您还需要用 parsnip 模型来构建解释器,而不是用 workflow()——后者期望自己来处理那些既非结果、也非预测变量的列:
# Reproducible example (reprex transcript; lines starting with `#>` are the
# recorded output): fit an RBF SVM workflow on the Pima diabetes data, then
# build the DALEX explainer from the extracted parsnip fit and the baked
# predictor columns only, so removed / non-predictor columns never reach it.
library(tidymodels)
# Data
data("PimaIndiansDiabetes", package = "mlbench")
dat <- PimaIndiansDiabetes
# Add a character grouping column: rows 1-384 "group 1", rows 385-768 "group 2".
dat$some_new_group[1:384] <- "group 1"
dat$some_new_group[385:768] <- "group 2"
# Split
set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)
dat_test <- testing(ind)
# Recipes
# some_new_group gets the non-predictor role "group_var", pressure is removed,
# and a Yeo-Johnson step is applied to the numeric predictors.
svm_rec <-
recipe(diabetes ~., data = dat_train) %>%
update_role(some_new_group, new_role = "group_var") %>%
step_rm(pressure) %>%
step_YeoJohnson(all_numeric_predictors())
# Model spec
svm_spec <-
svm_rbf() %>%
set_mode("classification") %>%
set_engine("kernlab")
# Train
svm_trained <-
workflow(svm_rec, svm_spec) %>%
fit(dat_train)
# Explainer
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.0).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#>
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#>
#> explain
# The explainer is given the parsnip fit (not the workflow) together with the
# prepped recipe's training data restricted to predictor columns.
# NOTE(review): y is as.numeric() of what appears to be a factor, i.e. codes
# 1/2 rather than 0/1 -- the residuals reported below span 0.11-1.90 against
# predictions in [0, 1]. Confirm whether `as.numeric(...) - 1` was intended.
svm_exp <- explain_tidymodels(
extract_fit_parsnip(svm_trained),
data = svm_rec %>% prep() %>% bake(new_data = NULL, all_predictors()),
y = dat_train$diabetes %>% as.numeric(),
label = "SVM"
)
#> Preparation of a new explainer is initiated
#> -> model label : SVM
#> -> data : 576 rows 7 cols
#> -> data : tibble converted into a data.frame
#> -> target variable : 576 values
#> -> predict function : yhat.model_fit will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package parsnip , ver. 0.2.1 , task classification ( default )
#> -> predicted values : numerical, min = 0.08057345 , mean = 0.3540662 , max = 0.9357536
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = 0.1083522 , mean = 0.9948921 , max = 1.895405
#> A new explainer has been created!
# Variable importance
set.seed(123)
# NOTE(review): the plot title below says "50 permutations", but no `B` is
# passed here -- confirm the default number of permutations model_parts() uses.
svm_vp <- model_parts(svm_exp, type = "variable_importance")
svm_vp
#> variable mean_dropout_loss label
#> 1 _full_model_ 0.6861190 SVM
#> 2 glucose 0.5919956 SVM
#> 3 mass 0.6673947 SVM
#> 4 pregnant 0.6700007 SVM
#> 5 age 0.6701185 SVM
#> 6 pedigree 0.6702812 SVM
#> 7 triceps 0.6760106 SVM
#> 8 insulin 0.6777355 SVM
#> 9 _baseline_ 0.5020752 SVM
# Only the seven baked predictors appear above; pressure and some_new_group
# are absent.
plot(svm_vp) +
ggtitle("Mean-variable importance over 50 permutations", "")
创建于 2022-05-03,由 reprex package (v2.0.1) 生成
如果您的工作流程中有这些不应用于解释的“额外”变量,那么您需要做一些额外的工作,不能只依赖 workflow()。
我想从图中删除某些变量。
# Reproducible example: fit an RBF SVM to the Pima diabetes data with a
# tidymodels workflow, then plot DALEX permutation variable importance.

# Packages ----
library(tidymodels)
library(mlbench)

# Data ----
data("PimaIndiansDiabetes")
dat <- PimaIndiansDiabetes
# Build the grouping column in a single assignment: the first 384 rows are
# "group 1", the remaining 384 rows are "group 2". Assigning into a slice of a
# column that does not exist yet only works when the slice length happens to
# divide nrow(dat).
dat$some_new_group <- rep(c("group 1", "group 2"), each = 384)

# Split ----
set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)
dat_test <- testing(ind)

# Recipe ----
# some_new_group is kept in the data but given a non-predictor role;
# pressure is removed before modelling.
svm_rec <-
  recipe(diabetes ~ ., data = dat_train) %>%
  update_role(some_new_group, new_role = "group_var") %>%
  step_rm(pressure) %>%
  step_YeoJohnson(all_numeric_predictors())

# Model spec ----
svm_spec <-
  svm_rbf() %>%
  set_mode("classification") %>%
  set_engine("kernlab")

# Workflow ----
svm_wf <-
  workflow() %>%
  add_recipe(svm_rec) %>%
  add_model(svm_spec)

# Train ----
svm_trained <- fit(svm_wf, dat_train)

# Explainer ----
library(DALEXtra)
svm_exp <- explain_tidymodels(
  svm_trained,
  data = dat %>% select(-diabetes),
  # as.numeric() on a factor returns the level codes 1/2; subtract 1 so the
  # target is 0/1 and comparable with the predicted class probability.
  y = as.numeric(dat$diabetes) - 1,
  label = "SVM"
)

# Variable importance ----
set.seed(123)
# B = 50 so that the plot title below ("50 permutations") is accurate;
# without it model_parts() uses its smaller default number of permutations.
svm_vp <- model_parts(svm_exp, type = "variable_importance", B = 50)
svm_vp
plot(svm_vp) +
  ggtitle("Mean-variable importance over 50 permutations", "")
请注意,在上面的 recipe 中,我用 step_rm() 删除了变量 pressure;另外数据中新增了一个分类变量 some_new_group(在 recipe 中被设为 "group_var" 角色,不作为预测变量)。
所以,我可以像这样从图中手动删除变量 pressure 和 some_new_group:
# Drop the variables the model never used before plotting. `%in%` tests every
# row against both names; `!=` against a length-2 vector is recycled
# element-wise and would leave some of those rows in place.
plot(svm_vp %>% filter(!variable %in% c("pressure", "some_new_group"))) +
  ggtitle("Mean-variable importance over 50 permutations", "")
但是,能否在运行 explain_tidymodels() 或 model_parts() 时就把这些变量排除掉?
如果数据中存在既不是预测变量也不是结果变量、而是交由 workflow() 处理的变量(例如被您删除的变量和您的分组变量),您需要确保只把结果变量和预测变量传给 explain_tidymodels()。您还需要用 parsnip 模型来构建解释器,而不是用 workflow()——后者期望自己来处理那些既非结果、也非预测变量的列:
# Reproducible example (reprex transcript; lines starting with `#>` are the
# recorded output): fit an RBF SVM workflow on the Pima diabetes data, then
# build the DALEX explainer from the extracted parsnip fit and the baked
# predictor columns only, so removed / non-predictor columns never reach it.
library(tidymodels)
# Data
data("PimaIndiansDiabetes", package = "mlbench")
dat <- PimaIndiansDiabetes
# Add a character grouping column: rows 1-384 "group 1", rows 385-768 "group 2".
dat$some_new_group[1:384] <- "group 1"
dat$some_new_group[385:768] <- "group 2"
# Split
set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)
dat_test <- testing(ind)
# Recipes
# some_new_group gets the non-predictor role "group_var", pressure is removed,
# and a Yeo-Johnson step is applied to the numeric predictors.
svm_rec <-
recipe(diabetes ~., data = dat_train) %>%
update_role(some_new_group, new_role = "group_var") %>%
step_rm(pressure) %>%
step_YeoJohnson(all_numeric_predictors())
# Model spec
svm_spec <-
svm_rbf() %>%
set_mode("classification") %>%
set_engine("kernlab")
# Train
svm_trained <-
workflow(svm_rec, svm_spec) %>%
fit(dat_train)
# Explainer
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.0).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#>
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#>
#> explain
# The explainer is given the parsnip fit (not the workflow) together with the
# prepped recipe's training data restricted to predictor columns.
# NOTE(review): y is as.numeric() of what appears to be a factor, i.e. codes
# 1/2 rather than 0/1 -- the residuals reported below span 0.11-1.90 against
# predictions in [0, 1]. Confirm whether `as.numeric(...) - 1` was intended.
svm_exp <- explain_tidymodels(
extract_fit_parsnip(svm_trained),
data = svm_rec %>% prep() %>% bake(new_data = NULL, all_predictors()),
y = dat_train$diabetes %>% as.numeric(),
label = "SVM"
)
#> Preparation of a new explainer is initiated
#> -> model label : SVM
#> -> data : 576 rows 7 cols
#> -> data : tibble converted into a data.frame
#> -> target variable : 576 values
#> -> predict function : yhat.model_fit will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package parsnip , ver. 0.2.1 , task classification ( default )
#> -> predicted values : numerical, min = 0.08057345 , mean = 0.3540662 , max = 0.9357536
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = 0.1083522 , mean = 0.9948921 , max = 1.895405
#> A new explainer has been created!
# Variable importance
set.seed(123)
# NOTE(review): the plot title below says "50 permutations", but no `B` is
# passed here -- confirm the default number of permutations model_parts() uses.
svm_vp <- model_parts(svm_exp, type = "variable_importance")
svm_vp
#> variable mean_dropout_loss label
#> 1 _full_model_ 0.6861190 SVM
#> 2 glucose 0.5919956 SVM
#> 3 mass 0.6673947 SVM
#> 4 pregnant 0.6700007 SVM
#> 5 age 0.6701185 SVM
#> 6 pedigree 0.6702812 SVM
#> 7 triceps 0.6760106 SVM
#> 8 insulin 0.6777355 SVM
#> 9 _baseline_ 0.5020752 SVM
# Only the seven baked predictors appear above; pressure and some_new_group
# are absent.
plot(svm_vp) +
ggtitle("Mean-variable importance over 50 permutations", "")
创建于 2022-05-03,由 reprex package (v2.0.1) 生成
如果您的工作流程中有这些不应用于解释的“额外”变量,那么您需要做一些额外的工作,不能只依赖 workflow()。