调整后的预测 Tidymodels
Adjusted Predictions Tidymodels
有人知道如何将 marginaleffects
包中的 predictions()
与 tidymodels
配合使用吗？在这个玩具示例中，我想获得变量 state
的预测值，同时将所有其他变量保持在其基准水平或平均值。
library(liver)
library(tidymodels)
library(marginaleffects)
# Copy the `churn` data set into a plain data frame.
df_churn <- data.frame(churn)

# Split into training (75%) and test sets, stratified on the outcome.
churn_split <- initial_split(df_churn, prop = 0.75, strata = churn)
churn_train <- training(churn_split)
churn_test <- testing(churn_split)

# Penalised logistic regression via glmnet:
# penalty = lambda (left to be tuned), mixture = alpha (fixed at 1).
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet") %>%
  set_mode("classification")

# Pre-processing: drop highly correlated numeric columns, normalise the
# numeric columns, then dummy-encode every nominal column except the outcome.
churn_recipe <- recipe(churn ~ ., data = churn_train) %>%
  step_corr(all_numeric(), threshold = 0.9) %>%
  step_normalize(all_numeric()) %>%
  step_dummy(all_nominal(), -all_outcomes())

# Bundle the model specification and the recipe into one workflow.
churn_wkfl <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(churn_recipe)

# 10-fold cross-validation folds, stratified on the outcome.
set.seed(1)
churn_folds <- vfold_cv(churn_train, v = 10, strata = churn)

# Tune over a grid of 25 candidate values, scored by ROC AUC.
set.seed(1)
glmnet_tuning <- tune_grid(
  churn_wkfl,
  resamples = churn_folds,
  grid = 25,
  metrics = metric_set(roc_auc)
)

# Keep the hyperparameters with the best ROC AUC.
best_glmnet_model <- select_best(glmnet_tuning, metric = "roc_auc")

# Finalise the workflow, fit it, and attempt to get adjusted predictions.
# NOTE: this last pipeline does not work; it is the step being asked about.
final_churn_wkfl <- churn_wkfl %>%
  finalize_workflow(best_glmnet_model) %>%
  fit(churn_train) %>%
  tidy() %>%
  predictions(variables = c("state"))
不幸的是，glmnet 不是 marginaleffects 支持的模型之一。
您可以将其切换为受支持的模型之一（如常规 glm()
），然后借助 extract_fit_engine()
提取底层拟合对象，这样就可以正常工作了。
library(tidymodels)
library(marginaleffects)
# Load the example churn data set.
data("mlc_churn")

# Reproducible 75/25 split, stratified on the outcome.
set.seed(123)
churn_split <- mlc_churn %>%
  initial_split(prop = 0.75, strata = churn)
churn_train <- churn_split %>% training()
churn_test <- churn_split %>% testing()

# Pre-processing: drop highly correlated numeric columns, normalise the
# numeric columns, then dummy-encode every nominal column except the outcome.
churn_recipe <- recipe(churn ~ ., data = churn_train) %>%
  step_corr(all_numeric(), threshold = 0.9) %>%
  step_normalize(all_numeric()) %>%
  step_dummy(all_nominal(), -all_outcomes())

# Recipe + logistic regression with its default engine = workflow.
churn_wkfl <- workflow(
  preprocessor = churn_recipe,
  spec = logistic_reg()
)

# Fit the workflow, pull out the underlying engine fit, and request
# adjusted predictions for one numeric predictor.
churn_wkfl %>%
  fit(churn_train) %>%
  extract_fit_engine() %>%
  predictions(variables = "total_intl_calls") %>%
  as_tibble()
#> # A tibble: 5 × 71
#> rowid type predicted std.error conf.low conf.high account_length
#> <int> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 1 response 0.895 0.0119 0.870 0.916 1.76e-17
#> 2 2 response 0.917 0.00620 0.904 0.928 1.76e-17
#> 3 3 response 0.923 0.00543 0.912 0.933 1.76e-17
#> 4 4 response 0.934 0.00549 0.923 0.944 1.76e-17
#> 5 5 response 0.977 0.00840 0.953 0.989 1.76e-17
#> # … with 64 more variables: number_vmail_messages <dbl>,
#> # total_day_minutes <dbl>, total_day_calls <dbl>, total_eve_minutes <dbl>,
#> # total_eve_calls <dbl>, total_night_calls <dbl>, total_night_charge <dbl>,
#> # total_intl_minutes <dbl>, number_customer_service_calls <dbl>,
#> # state_AL <dbl>, state_AR <dbl>, state_AZ <dbl>, state_CA <dbl>,
#> # state_CO <dbl>, state_CT <dbl>, state_DC <dbl>, state_DE <dbl>,
#> # state_FL <dbl>, state_GA <dbl>, state_HI <dbl>, state_IA <dbl>, …
由 reprex 包 (v2.0.1)
于 2022-03-25 创建
请注意，我没有使用 variables = c("state")
，而是改用了其中一个连续型数值预测变量。这是因为 recipe 中的 step_dummy() 已将 state 拆分为 state_AL、state_AR 等虚拟变量列，底层拟合对象中不再有名为 state 的变量。
有人知道如何将 marginaleffects
包中的 predictions()
与 tidymodels
配合使用吗？在这个玩具示例中，我想获得变量 state
的预测值，同时将所有其他变量保持在其基准水平或平均值。
library(liver)
library(tidymodels)
library(marginaleffects)
# Copy the `churn` data set into a plain data frame.
df_churn <- data.frame(churn)

# Split into training (75%) and test sets, stratified on the outcome.
churn_split <- initial_split(df_churn, prop = 0.75, strata = churn)
churn_train <- training(churn_split)
churn_test <- testing(churn_split)

# Penalised logistic regression via glmnet:
# penalty = lambda (left to be tuned), mixture = alpha (fixed at 1).
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet") %>%
  set_mode("classification")

# Pre-processing: drop highly correlated numeric columns, normalise the
# numeric columns, then dummy-encode every nominal column except the outcome.
churn_recipe <- recipe(churn ~ ., data = churn_train) %>%
  step_corr(all_numeric(), threshold = 0.9) %>%
  step_normalize(all_numeric()) %>%
  step_dummy(all_nominal(), -all_outcomes())

# Bundle the model specification and the recipe into one workflow.
churn_wkfl <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(churn_recipe)

# 10-fold cross-validation folds, stratified on the outcome.
set.seed(1)
churn_folds <- vfold_cv(churn_train, v = 10, strata = churn)

# Tune over a grid of 25 candidate values, scored by ROC AUC.
set.seed(1)
glmnet_tuning <- tune_grid(
  churn_wkfl,
  resamples = churn_folds,
  grid = 25,
  metrics = metric_set(roc_auc)
)

# Keep the hyperparameters with the best ROC AUC.
best_glmnet_model <- select_best(glmnet_tuning, metric = "roc_auc")

# Finalise the workflow, fit it, and attempt to get adjusted predictions.
# NOTE: this last pipeline does not work; it is the step being asked about.
final_churn_wkfl <- churn_wkfl %>%
  finalize_workflow(best_glmnet_model) %>%
  fit(churn_train) %>%
  tidy() %>%
  predictions(variables = c("state"))
不幸的是，glmnet 不是 marginaleffects 支持的模型之一。
您可以将其切换为受支持的模型之一（如常规 glm()
），然后借助 extract_fit_engine()
提取底层拟合对象，这样就可以正常工作了。
library(tidymodels)
library(marginaleffects)
# Load the example churn data set.
data("mlc_churn")

# Reproducible 75/25 split, stratified on the outcome.
set.seed(123)
churn_split <- mlc_churn %>%
  initial_split(prop = 0.75, strata = churn)
churn_train <- churn_split %>% training()
churn_test <- churn_split %>% testing()

# Pre-processing: drop highly correlated numeric columns, normalise the
# numeric columns, then dummy-encode every nominal column except the outcome.
churn_recipe <- recipe(churn ~ ., data = churn_train) %>%
  step_corr(all_numeric(), threshold = 0.9) %>%
  step_normalize(all_numeric()) %>%
  step_dummy(all_nominal(), -all_outcomes())

# Recipe + logistic regression with its default engine = workflow.
churn_wkfl <- workflow(
  preprocessor = churn_recipe,
  spec = logistic_reg()
)

# Fit the workflow, pull out the underlying engine fit, and request
# adjusted predictions for one numeric predictor.
churn_wkfl %>%
  fit(churn_train) %>%
  extract_fit_engine() %>%
  predictions(variables = "total_intl_calls") %>%
  as_tibble()
#> # A tibble: 5 × 71
#> rowid type predicted std.error conf.low conf.high account_length
#> <int> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 1 response 0.895 0.0119 0.870 0.916 1.76e-17
#> 2 2 response 0.917 0.00620 0.904 0.928 1.76e-17
#> 3 3 response 0.923 0.00543 0.912 0.933 1.76e-17
#> 4 4 response 0.934 0.00549 0.923 0.944 1.76e-17
#> 5 5 response 0.977 0.00840 0.953 0.989 1.76e-17
#> # … with 64 more variables: number_vmail_messages <dbl>,
#> # total_day_minutes <dbl>, total_day_calls <dbl>, total_eve_minutes <dbl>,
#> # total_eve_calls <dbl>, total_night_calls <dbl>, total_night_charge <dbl>,
#> # total_intl_minutes <dbl>, number_customer_service_calls <dbl>,
#> # state_AL <dbl>, state_AR <dbl>, state_AZ <dbl>, state_CA <dbl>,
#> # state_CO <dbl>, state_CT <dbl>, state_DC <dbl>, state_DE <dbl>,
#> # state_FL <dbl>, state_GA <dbl>, state_HI <dbl>, state_IA <dbl>, …
由 reprex 包 (v2.0.1)
于 2022-03-25 创建
请注意，我没有使用 variables = c("state")
，而是改用了其中一个连续型数值预测变量。这是因为 recipe 中的 step_dummy() 已将 state 拆分为 state_AL、state_AR 等虚拟变量列，底层拟合对象中不再有名为 state 的变量。