Tidymodels / XGBoost error in last_fit with rsplit value
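I'm trying to follow the tutorial here - https://juliasilge.com/blog/xgboost-tune-volleyball/
I'm applying it to the latest Tidy Tuesday dataset on Great Lakes fish stocking, trying to predict the agency from a number of other variables.
All of the code below works except the last line, which gives the following error: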
> final_res <- last_fit(final_xgb, stock_folds)
Error: Each element of `splits` must be an `rsplit` object.
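I searched for that error and found this page - https://github.com/tidymodels/rsample/issues/175
That page calls it a bug, and it seems to have been fixed, but it concerns initial_time_split, not the initial_split I'm using. I'd rather not change my split, because then I would have to re-run the xgboost tuning, which took 9 hours. What is going wrong here?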
# Setup ----
library(tidyverse)
library(tidymodels)
stocked <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-06-08/stocked.csv')
stocked_modeling <- stocked %>%
  mutate(AGENCY = case_when(
    AGENCY != "OMNR" ~ "other",
    TRUE ~ AGENCY
  )) %>%
  select(-SID, -MONTH, -DAY, -LATITUDE, -LONGITUDE, -GRID, -STRAIN, -AGEMONTH,
         -MARK_EFF, -TAG_NO, -TAG_RET, -LENGTH, -WEIGHT, -CONDITION, -LOT_CODE,
         -NOTES, -VALIDATION, -LS_MGMT, -STAT_DIST, -ST_SITE, -YEAR_CLASS, -STOCK_METH) %>%
  mutate_if(is.character, factor) %>%
  drop_na()
# Start making model ----
set.seed(123)
stock_split <- initial_split(stocked_modeling, strata = AGENCY)
stock_train <- training(stock_split)
stock_test <- testing(stock_split)
xgb_spec <- boost_tree(
  trees = 1000,
  tree_depth = tune(), min_n = tune(), loss_reduction = tune(),
  sample_size = tune(), mtry = tune(),
  learn_rate = tune()
) %>%
  set_engine("xgboost") %>%
  set_mode("classification")
xgb_grid <- grid_latin_hypercube(
  tree_depth(),
  min_n(),
  loss_reduction(),
  sample_size = sample_prop(),
  finalize(mtry(), stock_train),
  learn_rate(),
  size = 20
)
xgb_workflow <- workflow() %>%
  add_formula(AGENCY ~ .) %>%
  add_model(xgb_spec)
set.seed(123)
stock_folds <- vfold_cv(stock_train, strata = AGENCY)
doParallel::registerDoParallel()
# BEWARE, THIS CODE BELOW TOOK 9 HOURS TO RUN
set.seed(234)
xgb_res <- tune_grid(
  xgb_workflow,
  resamples = stock_folds,
  grid = xgb_grid,
  control = control_grid(save_pred = TRUE)
)
# Explore results
best_auc <- select_best(xgb_res, metric = "roc_auc")
final_xgb <- finalize_workflow(
  xgb_workflow,
  best_auc
)
final_res <- last_fit(final_xgb, stock_folds)
If we look at the documentation for last_fit(), we see that split must be
"An rsplit object created from `rsample::initial_split()`."
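To make the distinction concrete, here is a small sketch comparing the two object types. It assumes only the rsample package; the data frame `toy` and the names `toy_split` and `toy_folds` are hypothetical stand-ins for `stocked_modeling`, `stock_split` and `stock_folds`. The point is that `initial_split()` returns a single rsplit, while `vfold_cv()` returns an rset (a tibble whose `splits` column holds one rsplit per fold).

```r
library(rsample)

# Hypothetical toy data standing in for stocked_modeling
set.seed(123)
toy <- data.frame(
  AGENCY = factor(rep(c("OMNR", "other"), each = 50)),
  x = rnorm(100)
)

toy_split <- initial_split(toy, strata = AGENCY)                   # one rsplit
toy_folds <- vfold_cv(training(toy_split), v = 5, strata = AGENCY) # an rset of 5 rsplits

class(toy_split)  # includes "rsplit"
class(toy_folds)  # includes "vfold_cv" and "rset", but not "rsplit"

inherits(toy_split, "rsplit")             # TRUE: the kind of object last_fit() expects
inherits(toy_folds, "rsplit")             # FALSE: the kind of object passed in the failing call
inherits(toy_folds$splits[[1]], "rsplit") # TRUE: each individual fold is an rsplit
```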
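You accidentally passed the cross-validation fold object stock_folds to split, but you should pass the rsplit object stock_split instead:
final_res <- last_fit(final_xgb, stock_split)
This does not require re-running the 9-hour tune_grid() call, and you don't need to change initial_split() either: last_fit() fits the already-finalized workflow once on training(stock_split) and evaluates it on testing(stock_split).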
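Below is a minimal end-to-end sketch of the corrected pattern that runs in seconds. It is an assumption-laden stand-in, not the original model: the tuned XGBoost workflow is swapped for a plain logistic regression (`logistic_reg()` with the `"glm"` engine), and `toy`, `toy_split` and `toy_wf` are hypothetical names. It shows that `last_fit()` takes the rsplit and that test-set results are then read with `collect_metrics()` and `collect_predictions()`.

```r
library(tidymodels)

# Hypothetical toy data standing in for stocked_modeling
set.seed(123)
toy <- tibble(
  AGENCY = factor(rep(c("OMNR", "other"), each = 100)),
  x = c(rnorm(100, mean = 0), rnorm(100, mean = 1))
)

toy_split <- initial_split(toy, strata = AGENCY)

# A cheap glm model swapped in for the finalized XGBoost workflow
toy_wf <- workflow() %>%
  add_formula(AGENCY ~ x) %>%
  add_model(logistic_reg() %>% set_engine("glm"))

# Pass the rsplit (not a vfold_cv rset): one fit on training(), one evaluation on testing()
toy_res <- last_fit(toy_wf, toy_split)

collect_metrics(toy_res)     # default classification metrics on the test set
collect_predictions(toy_res) # one row per test-set observation

# The fitted workflow can be pulled out for predicting on new data
fitted_wf <- extract_workflow(toy_res)
```

In the question's setup the equivalent is last_fit(final_xgb, stock_split); xgb_res and best_auc from the earlier tuning run are reused as they are.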