R - 来自工作流的部分依赖图
R - Partial dependence plots from workflow
我在 R 中按以下步骤训练随机森林并用它进行预测:
# 4-fold cross-validation on the training data, stratified on the outcome
set.seed(123456)
cv_folds <- Data_train %>%
  vfold_cv(v = 4, strata = Lead_week)

# Create a recipe: Lead_week is modelled on the listed predictors and
# Leeftijd is normalized
rf_mod_recipe <- recipe(
  Lead_week ~ Jaar + Aantal + Verzekering + Leeftijd + Retentie +
    Aantal_proeven + Geslacht + FLG_ADVERTISING + FLG_MAIL +
    FLG_PHONE + FLG_EMAIL + Proef1 + Proef2 + Regio +
    Month + AC,
  data = Data_train
) %>%
  step_normalize(Leeftijd)

# Specify the model: ranger random forest for regression; mtry and min_n are
# left as tune() placeholders and filled in after tuning
rf_mod <- rand_forest(mtry = tune(), min_n = tune(), trees = 200) %>%
  set_mode("regression") %>%
  set_engine("ranger", importance = "permutation")

# Create a workflow bundling the model and the recipe
rf_mod_workflow <- workflow() %>%
  add_model(rf_mod) %>%
  add_recipe(rf_mod_recipe)
rf_mod_workflow

# State our error metrics
class_metrics <- metric_set(rmse, mae)

# Regular tuning grid with 5 levels per hyperparameter
rf_grid <- grid_regular(
  mtry(range = c(5, 15)),
  min_n(range = c(10, 200)),
  levels = 5
)
rf_grid

# Train the model on every fold / grid combination
set.seed(654321)
rf_tune_res <- tune_grid(
  rf_mod_workflow,
  resamples = cv_folds,
  grid = rf_grid,
  metrics = class_metrics
)

# Collect the resampled metrics for each candidate
rf_tune_res %>%
  collect_metrics()

# Select the candidate with the best RMSE. The metric is passed by name:
# recent tune releases reject a positional metric ("`...` must be empty").
best_rmse <- select_best(rf_tune_res, metric = "rmse")

# Plug the selected hyperparameters into the workflow defined above; the
# workflow does not need to be rebuilt before finalizing
rf_final_wf <- finalize_workflow(rf_mod_workflow, best_rmse)
rf_final_wf
# A finalized workflow is still untrained; predict() needs a fitted workflow,
# so fit it on the training data first.
rf_final_wf_fit <- fit(rf_final_wf, data = Data_train)

# NOTE(review): `grid` is not defined in the visible script. It is assumed to
# be a data frame containing every predictor used by the recipe, including
# AC -- confirm against the rest of the file.
# AC is taken from the data being predicted on: a recipe is not a data frame,
# so select() cannot pull a column out of `rf_mod_recipe`.
predict(rf_final_wf_fit, new_data = grid) %>%
  bind_cols(grid %>% select(AC)) %>%
  ggplot(aes(y = .pred, x = AC)) +
  geom_path()
检索样本内性能后,我使用工作流对保留数据进行预测。
# Final fit through last_fit(). `splits` is defined outside this excerpt
# (presumably the initial train/test split -- confirm).
set.seed(56789)
rf_final_fit <- last_fit(rf_final_wf, splits, metrics = class_metrics)

# Pull the predictions out of the final fit and summarise them
summary_rf <- collect_predictions(rf_final_fit)
summary(summary_rf$.pred)

# Metrics computed by last_fit() for the requested metric set
collect_metrics(rf_final_fit)
所以我先用交叉验证调参,最后在保留数据上进行测试。但是,怎样才能得到部分依赖图,从而“打开黑箱”(open the black box)?
我们建议使用 DALEX 来完成这类模型可解释性任务,因为它对 tidymodels 有很好的支持。
获得最终拟合模型(例如随机森林)后,您需要:
- 创建 DALEX 解释器
- 计算 PDP
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.2.0).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> Additional features will be available after installation of: ggpubr.
#> Use 'install_dependencies()' to get all suggested dependencies
#>
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#>
#> explain
# Ames housing data: keep the log10 sale price and three predictors
data(ames)
ames_train <- transmute(
  ames,
  Sale_Price = log10(Sale_Price),
  Gr_Liv_Area = as.numeric(Gr_Liv_Area),
  Year_Built,
  Bldg_Type
)

# Random forest regression specification (ranger engine, 1000 trees)
rf_model <- rand_forest(trees = 1000) %>%
  set_engine("ranger") %>%
  set_mode("regression")

# Workflow: formula preprocessor plus the model specification
rf_wflow <- workflow() %>%
  add_formula(Sale_Price ~ Gr_Liv_Area + Year_Built + Bldg_Type) %>%
  add_model(rf_model)

# Fit the workflow on the prepared data
rf_fit <- fit(rf_wflow, data = ames_train)
# Build the DALEX explainer from the fitted workflow. The predictors
# (everything except Sale_Price) and the outcome are passed separately.
explainer_rf <- explain_tidymodels(
  rf_fit,
  data = ames_train %>% dplyr::select(-Sale_Price),
  y = ames_train[["Sale_Price"]],
  label = "random forest"
)
#> Preparation of a new explainer is initiated
#>   -> model label       :  random forest
#>   -> data              :  2930  rows  3  cols
#>   -> data              :  tibble converted into a data.frame
#>   -> target variable   :  2930  values
#>   -> predict function  :  yhat.workflow  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package tidymodels , ver. 0.1.3 , task regression (  default  )
#>   -> predicted values  :  numerical, min =  4.896018 , mean =  5.220595 , max =  5.518857
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.8083636 , mean =  4.509735e-05 , max =  0.3590898
#>   A new explainer has been created!
# Partial dependence profile for Gr_Liv_Area, computed per Bldg_Type group
pdp_rf <- model_profile(
  explainer_rf,
  N = NULL,
  variables = "Gr_Liv_Area",
  groups = "Bldg_Type"
)

# Custom plot of the aggregated profiles. The "random forest_" prefix is
# stripped from `_label_` before it is mapped to colour.
pdp_rf$agr_profiles %>%
  as_tibble() %>%
  mutate(`_label_` = stringr::str_remove(`_label_`, "random forest_")) %>%
  ggplot(aes(x = `_x_`, y = `_yhat_`, color = `_label_`)) +
  geom_line(size = 1.2, alpha = 0.8) +
  labs(
    x = "Gross living area",
    y = "Sale Price (log)",
    color = NULL,
    title = "Partial dependence profile for Ames housing sales",
    subtitle = "Predictions from a random forest model"
  )
由 reprex 包 (v2.0.0) 创建于 2021-05-27
看来我应该把 x 轴放在对数刻度上。
您可以调用 plot(pdp_rf)
以使用 DALEX 的默认绘图方法,但我在这里展示了如何使用底层计算的 PDP 制作更自定义的绘图。
我在 R 中按以下步骤训练随机森林并用它进行预测:
# 4-fold cross-validation on the training data, stratified on the outcome
set.seed(123456)
cv_folds <- Data_train %>%
  vfold_cv(v = 4, strata = Lead_week)

# Create a recipe: Lead_week is modelled on the listed predictors and
# Leeftijd is normalized
rf_mod_recipe <- recipe(
  Lead_week ~ Jaar + Aantal + Verzekering + Leeftijd + Retentie +
    Aantal_proeven + Geslacht + FLG_ADVERTISING + FLG_MAIL +
    FLG_PHONE + FLG_EMAIL + Proef1 + Proef2 + Regio +
    Month + AC,
  data = Data_train
) %>%
  step_normalize(Leeftijd)

# Specify the model: ranger random forest for regression; mtry and min_n are
# left as tune() placeholders and filled in after tuning
rf_mod <- rand_forest(mtry = tune(), min_n = tune(), trees = 200) %>%
  set_mode("regression") %>%
  set_engine("ranger", importance = "permutation")

# Create a workflow bundling the model and the recipe
rf_mod_workflow <- workflow() %>%
  add_model(rf_mod) %>%
  add_recipe(rf_mod_recipe)
rf_mod_workflow

# State our error metrics
class_metrics <- metric_set(rmse, mae)

# Regular tuning grid with 5 levels per hyperparameter
rf_grid <- grid_regular(
  mtry(range = c(5, 15)),
  min_n(range = c(10, 200)),
  levels = 5
)
rf_grid

# Train the model on every fold / grid combination
set.seed(654321)
rf_tune_res <- tune_grid(
  rf_mod_workflow,
  resamples = cv_folds,
  grid = rf_grid,
  metrics = class_metrics
)

# Collect the resampled metrics for each candidate
rf_tune_res %>%
  collect_metrics()

# Select the candidate with the best RMSE. The metric is passed by name:
# recent tune releases reject a positional metric ("`...` must be empty").
best_rmse <- select_best(rf_tune_res, metric = "rmse")

# Plug the selected hyperparameters into the workflow defined above; the
# workflow does not need to be rebuilt before finalizing
rf_final_wf <- finalize_workflow(rf_mod_workflow, best_rmse)
rf_final_wf
# A finalized workflow is still untrained; predict() needs a fitted workflow,
# so fit it on the training data first.
rf_final_wf_fit <- fit(rf_final_wf, data = Data_train)

# NOTE(review): `grid` is not defined in the visible script. It is assumed to
# be a data frame containing every predictor used by the recipe, including
# AC -- confirm against the rest of the file.
# AC is taken from the data being predicted on: a recipe is not a data frame,
# so select() cannot pull a column out of `rf_mod_recipe`.
predict(rf_final_wf_fit, new_data = grid) %>%
  bind_cols(grid %>% select(AC)) %>%
  ggplot(aes(y = .pred, x = AC)) +
  geom_path()
检索样本内性能后,我使用工作流对保留数据进行预测。
# Final fit through last_fit(). `splits` is defined outside this excerpt
# (presumably the initial train/test split -- confirm).
set.seed(56789)
rf_final_fit <- last_fit(rf_final_wf, splits, metrics = class_metrics)

# Pull the predictions out of the final fit and summarise them
summary_rf <- collect_predictions(rf_final_fit)
summary(summary_rf$.pred)

# Metrics computed by last_fit() for the requested metric set
collect_metrics(rf_final_fit)
所以我先用交叉验证调参,最后在保留数据上进行测试。但是,怎样才能得到部分依赖图,从而“打开黑箱”(open the black box)?
我们建议使用 DALEX 来完成这类模型可解释性任务,因为它对 tidymodels 有很好的支持。
获得最终拟合模型(例如随机森林)后,您需要:
- 创建 DALEX 解释器
- 计算 PDP
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.2.0).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> Additional features will be available after installation of: ggpubr.
#> Use 'install_dependencies()' to get all suggested dependencies
#>
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#>
#> explain
# Ames housing data: keep the log10 sale price and three predictors
data(ames)
ames_train <- transmute(
  ames,
  Sale_Price = log10(Sale_Price),
  Gr_Liv_Area = as.numeric(Gr_Liv_Area),
  Year_Built,
  Bldg_Type
)

# Random forest regression specification (ranger engine, 1000 trees)
rf_model <- rand_forest(trees = 1000) %>%
  set_engine("ranger") %>%
  set_mode("regression")

# Workflow: formula preprocessor plus the model specification
rf_wflow <- workflow() %>%
  add_formula(Sale_Price ~ Gr_Liv_Area + Year_Built + Bldg_Type) %>%
  add_model(rf_model)

# Fit the workflow on the prepared data
rf_fit <- fit(rf_wflow, data = ames_train)
# Build the DALEX explainer from the fitted workflow. The predictors
# (everything except Sale_Price) and the outcome are passed separately.
explainer_rf <- explain_tidymodels(
  rf_fit,
  data = ames_train %>% dplyr::select(-Sale_Price),
  y = ames_train[["Sale_Price"]],
  label = "random forest"
)
#> Preparation of a new explainer is initiated
#>   -> model label       :  random forest
#>   -> data              :  2930  rows  3  cols
#>   -> data              :  tibble converted into a data.frame
#>   -> target variable   :  2930  values
#>   -> predict function  :  yhat.workflow  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package tidymodels , ver. 0.1.3 , task regression (  default  )
#>   -> predicted values  :  numerical, min =  4.896018 , mean =  5.220595 , max =  5.518857
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.8083636 , mean =  4.509735e-05 , max =  0.3590898
#>   A new explainer has been created!
# Partial dependence profile for Gr_Liv_Area, computed per Bldg_Type group
pdp_rf <- model_profile(
  explainer_rf,
  N = NULL,
  variables = "Gr_Liv_Area",
  groups = "Bldg_Type"
)

# Custom plot of the aggregated profiles. The "random forest_" prefix is
# stripped from `_label_` before it is mapped to colour.
pdp_rf$agr_profiles %>%
  as_tibble() %>%
  mutate(`_label_` = stringr::str_remove(`_label_`, "random forest_")) %>%
  ggplot(aes(x = `_x_`, y = `_yhat_`, color = `_label_`)) +
  geom_line(size = 1.2, alpha = 0.8) +
  labs(
    x = "Gross living area",
    y = "Sale Price (log)",
    color = NULL,
    title = "Partial dependence profile for Ames housing sales",
    subtitle = "Predictions from a random forest model"
  )
由 reprex 包 (v2.0.0) 创建于 2021-05-27
看来我应该把 x 轴放在对数刻度上。
您可以调用 plot(pdp_rf)
以使用 DALEX 的默认绘图方法,但我在这里展示了如何使用底层计算的 PDP 制作更自定义的绘图。