LASSO 回归 - 使用 tidymodels 在 glmnet 中强制变量
LASSO regression - Force variables in glmnet with tidymodels
我正在使用 tidymodels 和 glmnet 的 LASSO 回归进行特征选择。
可以使用 penalty.factor 参数在 glmnet 中强制保留某些变量(例如,参见此处和此处)。
是否可以使用 tidymodels 做同样的事情?
library(tidymodels)
library(vip)
library(forcats)
library(dplyr)
library(ggplot2)
library(data.table)
# Split mtcars: 80% of the rows go to training, the remainder to testing.
datasplit <- rsample::initial_split(data = mtcars, prop = 0.8)
data_training <- rsample::training(x = datasplit)
data_testing <- rsample::testing(x = datasplit)
# Model specification: glmnet linear regression with mixture = 1 and a penalty
# left to be tuned. No engine-specific arguments are set here.
# (Open question of this post: is this where glmnet's penalty.factor goes?)
model_spec <- parsnip::set_engine(
  parsnip::linear_reg(penalty = tune::tune(), mixture = 1),
  engine = "glmnet"
)

# Recipe: mpg is the outcome, every other mtcars column is a predictor.
rec <- recipes::recipe(mpg ~ ., data = mtcars)

# Workflow bundling the recipe with the model specification.
wf <- workflows::add_model(
  workflows::add_recipe(workflows::workflow(), rec),
  model_spec
)
# Resamples: 2-fold cross-validation of the training set, repeated 3 times.
data_resample <- rsample::vfold_cv(data_training, v = 2, repeats = 3)

# Candidate penalties: a regular grid with 100 levels of dials::penalty().
hyperparam_grid <- dials::grid_regular(dials::penalty(), levels = 100)

# Metrics computed for every candidate: rsq, mape and mpe.
metrics <- yardstick::metric_set(
  yardstick::rsq,
  yardstick::mape,
  yardstick::mpe
)

# Evaluate every candidate penalty on every resample.
tune_grid_results <- tune::tune_grid(
  wf,
  metrics = metrics,
  grid = hyperparam_grid,
  resamples = data_resample
)
# Pick the penalty that performs best on MAPE across the resamples.
# `metric` is passed by name: recent tune releases no longer accept the metric
# as a positional argument, while older releases accept the named form too.
selected_model <- tune::select_best(tune_grid_results, metric = "mape")

# Plug the chosen penalty into the workflow, refit on the whole training set
# and extract the underlying parsnip model fit.
final_model <- tune::finalize_workflow(wf, selected_model)
final_model_fit <- final_model %>%
  parsnip::fit(data = data_training) %>%
  workflows::extract_fit_parsnip()
# Variable importance at the selected penalty, stored as a data.table sorted
# by decreasing importance. Variable becomes a factor ordered by importance,
# which fixes the bar order in the plot below.
t_importance <- final_model_fit %>%
  vip::vi(lambda = selected_model$penalty) %>%
  dplyr::mutate(Variable = forcats::fct_reorder(Variable, Importance)) %>%
  data.table::data.table() %>%
  data.table::setorder(-Importance)

# Horizontal bar chart of importance, filled by the Sign column.
ggplot(t_importance, aes(x = Importance, y = Variable, fill = Sign)) +
  geom_col() +
  scale_x_continuous(expand = c(0, 0)) +
  labs(y = NULL) +
  theme_minimal()
由 reprex package (v2.0.1)
创建于 2022-03-14
如上面的评论所述,您可以在 set_engine() 中传递 penalty.factor 这类引擎特定(engine-specific)的参数:
library(tidyverse)
library(tidymodels)
library(vip)
#>
#> Attaching package: 'vip'
#> The following object is masked from 'package:utils':
#>
#> vi
# 80/20 train/test split of mtcars.
datasplit <- rsample::initial_split(data = mtcars, prop = 0.8)
car_train <- rsample::training(x = datasplit)
car_test <- rsample::testing(x = datasplit)

# Resamples for tuning: 2-fold cross-validation repeated 3 times.
car_folds <- rsample::vfold_cv(data = car_train, v = 2, repeats = 3)
您可以在此处将 penalty.factor 作为引擎特定(engine-specific)参数传递给模型规范:
# Predictors that must stay in the model. In glmnet a penalty factor of 0
# means no shrinkage, so that variable is always included; 1 applies the
# regular penalty. Deriving the vector from column names (instead of writing
# it positionally) documents which variables are forced and keeps its length
# tied to the predictors. For mtcars this yields c(0, rep(1, 7), 0, 0).
forced_vars <- c("cyl", "gear", "carb")
predictor_names <- setdiff(names(car_train), "mpg")
stopifnot(all(forced_vars %in% predictor_names))
penalty_factors <- as.numeric(!predictor_names %in% forced_vars)

# `!!` inlines the values into the specification so the fit does not depend
# on `penalty_factors` still being visible when the model is evaluated.
glmnet_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet", penalty.factor = !!penalty_factors)

# Formula-based workflow; with all-numeric mtcars columns the predictors
# enter the model in data-column order, matching `penalty_factors`.
car_wf <- workflow(mpg ~ ., glmnet_spec)

# Tune over 5 candidate penalty values on the resamples, then print.
glmnet_res <- tune_grid(car_wf, resamples = car_folds, grid = 5)
glmnet_res
#> # Tuning results
#> # 2-fold cross-validation repeated 3 times
#> # A tibble: 6 × 5
#> splits id id2 .metrics .notes
#> <list> <chr> <chr> <list> <list>
#> 1 <split [12/13]> Repeat1 Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
#> 2 <split [13/12]> Repeat1 Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
#> 3 <split [12/13]> Repeat2 Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
#> 4 <split [13/12]> Repeat2 Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
#> 5 <split [12/13]> Repeat3 Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
#> 6 <split [13/12]> Repeat3 Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
# Select the penalty that performs best on RMSE across the resamples.
# `metric` is named explicitly because recent tune releases reject a metric
# passed positionally; older releases accept the named form as well.
best_penalty <- select_best(glmnet_res, metric = "rmse")

# Finalize the workflow with that penalty, refit on the whole training set
# and extract the underlying parsnip model fit.
final_fit <- car_wf %>%
  finalize_workflow(best_penalty) %>%
  fit(data = car_train) %>%
  extract_fit_parsnip()
# Variable importance at the selected penalty as a horizontal bar chart:
# bars are ordered by importance and filled by the Sign column.
vip::vi(final_fit, lambda = best_penalty$penalty) %>%
  dplyr::mutate(Variable = forcats::fct_reorder(Variable, Importance)) %>%
  ggplot2::ggplot(
    mapping = ggplot2::aes(x = Importance, y = Variable, fill = Sign)
  ) +
  ggplot2::geom_col() +
  ggplot2::scale_x_continuous(expand = c(0, 0)) +
  ggplot2::labs(y = NULL) +
  ggplot2::theme_minimal()
由 reprex package (v2.0.1)
创建于 2022-03-14
这确实需要您在创建模型规范时知道预测变量的数量,这对于包含许多特征工程步骤的复杂配方来说可能具有挑战性。
我正在使用 tidymodels 和 glmnet 的 LASSO 回归进行特征选择。
可以使用 penalty.factor 参数在 glmnet 中强制保留某些变量(例如,参见此处和此处)。
是否可以使用 tidymodels 做同样的事情?
library(tidymodels)
library(vip)
library(forcats)
library(dplyr)
library(ggplot2)
library(data.table)
# Split mtcars: 80% of the rows go to training, the remainder to testing.
datasplit <- rsample::initial_split(data = mtcars, prop = 0.8)
data_training <- rsample::training(x = datasplit)
data_testing <- rsample::testing(x = datasplit)
# Model specification: glmnet linear regression with mixture = 1 and a penalty
# left to be tuned. No engine-specific arguments are set here.
# (Open question of this post: is this where glmnet's penalty.factor goes?)
model_spec <- parsnip::set_engine(
  parsnip::linear_reg(penalty = tune::tune(), mixture = 1),
  engine = "glmnet"
)

# Recipe: mpg is the outcome, every other mtcars column is a predictor.
rec <- recipes::recipe(mpg ~ ., data = mtcars)

# Workflow bundling the recipe with the model specification.
wf <- workflows::add_model(
  workflows::add_recipe(workflows::workflow(), rec),
  model_spec
)
# Resamples: 2-fold cross-validation of the training set, repeated 3 times.
data_resample <- rsample::vfold_cv(data_training, v = 2, repeats = 3)

# Candidate penalties: a regular grid with 100 levels of dials::penalty().
hyperparam_grid <- dials::grid_regular(dials::penalty(), levels = 100)

# Metrics computed for every candidate: rsq, mape and mpe.
metrics <- yardstick::metric_set(
  yardstick::rsq,
  yardstick::mape,
  yardstick::mpe
)

# Evaluate every candidate penalty on every resample.
tune_grid_results <- tune::tune_grid(
  wf,
  metrics = metrics,
  grid = hyperparam_grid,
  resamples = data_resample
)
# Pick the penalty that performs best on MAPE across the resamples.
# `metric` is passed by name: recent tune releases no longer accept the metric
# as a positional argument, while older releases accept the named form too.
selected_model <- tune::select_best(tune_grid_results, metric = "mape")

# Plug the chosen penalty into the workflow, refit on the whole training set
# and extract the underlying parsnip model fit.
final_model <- tune::finalize_workflow(wf, selected_model)
final_model_fit <- final_model %>%
  parsnip::fit(data = data_training) %>%
  workflows::extract_fit_parsnip()
# Variable importance at the selected penalty, stored as a data.table sorted
# by decreasing importance. Variable becomes a factor ordered by importance,
# which fixes the bar order in the plot below.
t_importance <- final_model_fit %>%
  vip::vi(lambda = selected_model$penalty) %>%
  dplyr::mutate(Variable = forcats::fct_reorder(Variable, Importance)) %>%
  data.table::data.table() %>%
  data.table::setorder(-Importance)

# Horizontal bar chart of importance, filled by the Sign column.
ggplot(t_importance, aes(x = Importance, y = Variable, fill = Sign)) +
  geom_col() +
  scale_x_continuous(expand = c(0, 0)) +
  labs(y = NULL) +
  theme_minimal()
由 reprex package (v2.0.1)
创建于 2022-03-14
如上面的评论所述,您可以在 set_engine() 中传递 penalty.factor 这类引擎特定(engine-specific)的参数:
library(tidyverse)
library(tidymodels)
library(vip)
#>
#> Attaching package: 'vip'
#> The following object is masked from 'package:utils':
#>
#> vi
# 80/20 train/test split of mtcars.
datasplit <- rsample::initial_split(data = mtcars, prop = 0.8)
car_train <- rsample::training(x = datasplit)
car_test <- rsample::testing(x = datasplit)

# Resamples for tuning: 2-fold cross-validation repeated 3 times.
car_folds <- rsample::vfold_cv(data = car_train, v = 2, repeats = 3)
您可以在此处将 penalty.factor 作为引擎特定(engine-specific)参数传递给模型规范:
# Predictors that must stay in the model. In glmnet a penalty factor of 0
# means no shrinkage, so that variable is always included; 1 applies the
# regular penalty. Deriving the vector from column names (instead of writing
# it positionally) documents which variables are forced and keeps its length
# tied to the predictors. For mtcars this yields c(0, rep(1, 7), 0, 0).
forced_vars <- c("cyl", "gear", "carb")
predictor_names <- setdiff(names(car_train), "mpg")
stopifnot(all(forced_vars %in% predictor_names))
penalty_factors <- as.numeric(!predictor_names %in% forced_vars)

# `!!` inlines the values into the specification so the fit does not depend
# on `penalty_factors` still being visible when the model is evaluated.
glmnet_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet", penalty.factor = !!penalty_factors)

# Formula-based workflow; with all-numeric mtcars columns the predictors
# enter the model in data-column order, matching `penalty_factors`.
car_wf <- workflow(mpg ~ ., glmnet_spec)

# Tune over 5 candidate penalty values on the resamples, then print.
glmnet_res <- tune_grid(car_wf, resamples = car_folds, grid = 5)
glmnet_res
#> # Tuning results
#> # 2-fold cross-validation repeated 3 times
#> # A tibble: 6 × 5
#> splits id id2 .metrics .notes
#> <list> <chr> <chr> <list> <list>
#> 1 <split [12/13]> Repeat1 Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
#> 2 <split [13/12]> Repeat1 Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
#> 3 <split [12/13]> Repeat2 Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
#> 4 <split [13/12]> Repeat2 Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
#> 5 <split [12/13]> Repeat3 Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
#> 6 <split [13/12]> Repeat3 Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
# Select the penalty that performs best on RMSE across the resamples.
# `metric` is named explicitly because recent tune releases reject a metric
# passed positionally; older releases accept the named form as well.
best_penalty <- select_best(glmnet_res, metric = "rmse")

# Finalize the workflow with that penalty, refit on the whole training set
# and extract the underlying parsnip model fit.
final_fit <- car_wf %>%
  finalize_workflow(best_penalty) %>%
  fit(data = car_train) %>%
  extract_fit_parsnip()
# Variable importance at the selected penalty as a horizontal bar chart:
# bars are ordered by importance and filled by the Sign column.
vip::vi(final_fit, lambda = best_penalty$penalty) %>%
  dplyr::mutate(Variable = forcats::fct_reorder(Variable, Importance)) %>%
  ggplot2::ggplot(
    mapping = ggplot2::aes(x = Importance, y = Variable, fill = Sign)
  ) +
  ggplot2::geom_col() +
  ggplot2::scale_x_continuous(expand = c(0, 0)) +
  ggplot2::labs(y = NULL) +
  ggplot2::theme_minimal()
由 reprex package (v2.0.1)
创建于 2022-03-14
这确实需要您在创建模型规范时知道预测变量的数量,这对于包含许多特征工程步骤的复杂配方来说可能具有挑战性。