如何在 R 的 caret 和 rsample 中使用相同的交叉验证集

How can I use the same cross-validation sets in R caret and rsample?

我正在尝试把 caret::train() 代码改写为 tidymodels 工作流程,以此学习 tidymodels 生态系统。我认为我得到的结果差异源于 caret 和 rsample 所用重采样算法的不同。一位同事写了一个 gist,展示了即使使用相同的随机种子,两者生成的数据集划分也不相同:https://gist.github.com/bradleyboehmke/7794b79a07afb443da11d930ff84bed7

在下面这个简单模型中就能看到细微的差异(我认为两段代码是等价的):

library(caret)
library(tidyverse)
library(tidymodels)
data(ames)

# caret: linear model of sale price on above-ground living area,
# evaluated with 10-fold cross-validation.
set.seed(123)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = trainControl(method = "cv", number = 10)
)
cv_model1

与下面的 tidymodels 代码对比:

# tidymodels: the same model expressed as a workflow and evaluated on
# 10 folds created by rsample::vfold_cv().
set.seed(123)
folds <- vfold_cv(data = ames, v = 10)

# Model specification: linear regression using the "lm" engine.
the_lm_model <- linear_reg() %>%
  set_engine(engine = "lm")

# Formula-only recipe: no preprocessing steps are added.
the_rec <- recipe(Sale_Price ~ Gr_Liv_Area, data = ames)

# Bundle recipe and model so both are fitted together on every fold.
the_workflow <- workflow() %>%
  add_recipe(recipe = the_rec) %>%
  add_model(spec = the_lm_model)

the_results <- fit_resamples(the_workflow, resamples = folds)

collect_metrics(the_results)

有没有办法在 tidymodels 工作流程中(其重采样通常由 rsample::vfold_cv() 创建)直接使用 caret 的重采样(来自 caret::createFolds())?我希望弄清楚这个细节后,能用新的生态系统复现复杂的旧代码(用于教学)。

编辑:感谢 Julia Silge 的评论。

rsample 包中的函数 rsample2caret() 和 caret2rsample() 可用于在 caret 与 rsample 两种格式之间转换重采样对象。

下面的旧答案仍可用于把任意格式的重采样索引转换为 rsample 对象。

旧答案

下面是一种将 caret::createFolds() 的输出转换为 rsample 对象的方法:
library(caret)
library(tidyverse)
library(tidymodels)

data(ames)

# caret: row indices of the training (analysis) rows for each of 10 folds;
# returnTrain = TRUE returns the training rows instead of the held-out rows.
set.seed(123)
folds_train <- caret::createFolds(ames$Sale_Price, k = 10, returnTrain = TRUE)

# Held-out (assessment) rows for each fold: every row not used for training.
folds_test <- lapply(
  folds_train,
  function(train_rows) setdiff(seq_len(nrow(ames)), train_rows)
)

将训练索引和测试索引组合成包含 analysis(分析集)和 assessment(评估集)元素的列表,如 manual_rset() 文档中所述:
# Pair each fold's training and held-out indices in the list form accepted
# by rsample::make_splits() (elements named analysis and assessment).
rsplit <- Map(
  function(train_rows, test_rows) {
    list(analysis = train_rows, assessment = test_rows)
  },
  folds_train,
  folds_test
)

# Convert each index pair into an rsplit, then gather them into an rset
# whose ids are the caret fold names.
splits <- lapply(rsplit, make_splits, data = ames)
splits <- manual_rset(splits, ids = names(splits))
> splits
# Manual resampling 
# A tibble: 10 x 2
   splits             id    
   <named list>       <chr> 
 1 <split [2637/293]> Fold01
 2 <split [2638/292]> Fold02
 3 <split [2637/293]> Fold03
 4 <split [2637/293]> Fold04
 5 <split [2638/292]> Fold05
 6 <split [2637/293]> Fold06
 7 <split [2637/293]> Fold07
 8 <split [2636/294]> Fold08
 9 <split [2636/294]> Fold09
10 <split [2637/293]> Fold10

检查结果是否相同:

# Refit the caret model on exactly the folds built above.
# - method = "cv": without it, trainControl() falls back to its default
#   method ("boot"), and caret then reports these resamples as
#   "Bootstrapped (10 reps)" even though they are 10 CV folds.
# - indexOut = folds_test: pins the held-out rows to the same assessment
#   sets that are handed to rsample, instead of leaving caret to derive them.
set.seed(123)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = trainControl(
    method = "cv",
    index = folds_train,
    indexOut = folds_test
  )
)
> cv_model1
Linear Regression 

2930 samples
   1 predictor

No pre-processing
Resampling: Bootstrapped (10 reps) 
Summary of sample sizes: 2637, 2638, 2637, 2637, 2638, 2637, ... 
Resampling results:

  RMSE      Rsquared   MAE     
  56364.67  0.5066935  38575.21

Tuning parameter 'intercept' was held constant at a value of TRUE

# Model specification: linear regression using the "lm" engine.
the_lm_model <- linear_reg() %>%
  set_engine(engine = "lm")

# Formula-only recipe: no preprocessing steps are added.
the_rec <- recipe(Sale_Price ~ Gr_Liv_Area, data = ames)

# Bundle recipe and model into one workflow.
the_workflow <- workflow() %>%
  add_recipe(recipe = the_rec) %>%
  add_model(spec = the_lm_model)

# Evaluate the workflow on the caret-derived folds stored in `splits`.
set.seed(123)
the_results <- fit_resamples(the_workflow, resamples = splits)

collect_metrics(the_results)
> collect_metrics(the_results)
# A tibble: 2 x 6
  .metric .estimator      mean     n   std_err .config             
  <chr>   <chr>          <dbl> <int>     <dbl> <chr>               
1 rmse    standard   56365.       10 1782.     Preprocessor1_Model1
2 rsq     standard       0.507    10    0.0220 Preprocessor1_Model1

# Check that caret's cross-validated RMSE matches the rsample/tune estimate.
# The rmse row is selected by name rather than by position, so the check
# does not depend on the row order of collect_metrics(). isTRUE() turns
# all.equal()'s descriptive string on mismatch into a plain FALSE.
isTRUE(all.equal(
  cv_model1$results$RMSE,
  collect_metrics(the_results) %>%
    dplyr::filter(.metric == "rmse") %>%
    dplyr::pull(mean)
))
TRUE

也许有更直接的方法,但我对 tidymodels 还不够熟悉,无法确定。

如果您在调用 caret::train() 之前没有预先创建折叠:

# Let caret create its own 10 CV folds. The training and held-out row
# indices it used can then be read from cv_model1$control.
set.seed(123)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = trainControl(method = "cv", number = 10)
)

你可以使用

# Training-row indices caret used for each fold
cv_model1[["control"]][["index"]]

# Held-out row indices caret used for each fold
cv_model1[["control"]][["indexOut"]]

来创建一个 rsample 对象:

# Pair caret's training and held-out indices fold by fold, in the
# analysis/assessment list form used by rsample::make_splits().
rsplit <- Map(
  function(train_rows, test_rows) {
    list(analysis = train_rows, assessment = test_rows)
  },
  cv_model1[["control"]][["index"]],
  cv_model1[["control"]][["indexOut"]]
)

然后按上述步骤进行。

# Build one rsplit per fold and wrap them in an rset, reusing the fold
# names from `rsplit` as the resample ids.
splits <- manual_rset(
  lapply(rsplit, make_splits, data = ames),
  ids = names(rsplit)
)