如何在 R 的 caret 和 rsample 中使用相同的交叉验证集
How can I use the same crossvalidation sets in R caret and rsamples
我正在尝试通过将 caret::train()
代码转换为 tidymodels
工作流程来学习 tidymodels
生态系统。我得到的差异我认为是 caret
与 rsample
中重采样算法的副产品。一位同事写了一个 gist,展示了在随机种子相同的情况下两者对同一数据集产生的差异:https://gist.github.com/bradleyboehmke/7794b79a07afb443da11d930ff84bed7
您可以在下面这个简单模型中看到细微的差异(我认为两段代码是等价的):
library(caret)
library(tidyverse)
library(tidymodels)
data(ames)
# Reference fit: 10-fold cross-validation of a one-predictor linear model
# using caret.
set.seed(123)
cv_control <- trainControl(method = "cv", number = 10)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = cv_control
)
# Show the resampling summary of the fitted model.
cv_model1
对比
# The same model expressed as a tidymodels workflow, resampled on folds
# created by rsample.
set.seed(123)
folds <- vfold_cv(ames, v = 10)

# Model specification: linear regression fitted with the "lm" engine.
the_lm_model <- set_engine(linear_reg(), "lm")

# Preprocessing: only the model formula, no extra steps.
the_rec <- recipe(Sale_Price ~ Gr_Liv_Area, data = ames)

# Bundle the recipe and the model specification into one workflow.
the_workflow <- add_model(add_recipe(workflow(), the_rec), the_lm_model)

# Fit the workflow on every fold and summarise the resampled metrics.
the_results <- fit_resamples(the_workflow, folds)
collect_metrics(the_results)
在 tidymodels
工作流程(通常使用 rsample::vfold_cv()
创建重采样)中,是否有直接使用 caret
重采样(来自 caret::createFolds()
)的方法?我希望如果能弄清楚这个细节,就可以用新的生态系统复现复杂的旧代码(用于教学)。
编辑。感谢 Julia Silge 的评论。
函数
rsample2caret() 和 caret2rsample()
可用于在格式之间转换重采样对象。
下面的答案对于从任意格式转换为 rsample 很有用。
旧答案
这是一种将 caret::createFolds
的输出转换为 rsample
的方法
library(caret)
library(tidyverse)
library(tidymodels)
data(ames)
# Create the training-row indices for 10 folds.
set.seed(123)
folds_train <- caret::createFolds(
  y = ames$Sale_Price,
  k = 10,
  returnTrain = TRUE
)

# The held-out rows of a fold are all row positions that are not in its
# training set.
all_rows <- seq_along(ames$Sale_Price)
folds_test <- lapply(
  folds_train,
  function(train_idx) setdiff(all_rows, train_idx)
)
将训练索引和测试索引组合成包含 analysis 与 assessment 两个元素的列表,如 manual_rset
的文档中所述
# Pair each fold's training and held-out indices as
# list(analysis = ..., assessment = ...), the form passed to make_splits().
rsplit <- map2(
  folds_train,
  folds_test,
  function(train_idx, test_idx) {
    list(analysis = train_idx, assessment = test_idx)
  }
)

# Build one split object per fold on `ames`, then gather them into a manual
# resampling set that reuses the fold names as ids.
split_list <- lapply(rsplit, function(idx) make_splits(idx, data = ames))
splits <- manual_rset(split_list, names(split_list))
> splits
# Manual resampling
# A tibble: 10 x 2
splits id
<named list> <chr>
1 <split [2637/293]> Fold01
2 <split [2638/292]> Fold02
3 <split [2637/293]> Fold03
4 <split [2637/293]> Fold04
5 <split [2638/292]> Fold05
6 <split [2637/293]> Fold06
7 <split [2637/293]> Fold07
8 <split [2636/294]> Fold08
9 <split [2636/294]> Fold09
10 <split [2637/293]> Fold10
检查结果是否相同:
# Refit the caret model on the pre-computed training indices.
# NOTE(review): `method` is not set in trainControl(); presumably that is why
# the printed summary labels the resampling "Bootstrapped" even though the
# supplied indices are the 10 cross-validation training sets -- confirm.
set.seed(123)
index_control <- trainControl(index = folds_train)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = index_control
)
> cv_model1
Linear Regression
2930 samples
1 predictor
No pre-processing
Resampling: Bootstrapped (10 reps)
Summary of sample sizes: 2637, 2638, 2637, 2637, 2638, 2637, ...
Resampling results:
RMSE Rsquared MAE
56364.67 0.5066935 38575.21
Tuning parameter 'intercept' was held constant at a value of TRUE
# Model specification: linear regression fitted with the "lm" engine.
the_lm_model <- set_engine(linear_reg(), "lm")

# Preprocessing: only the model formula, no extra steps.
the_rec <- recipe(Sale_Price ~ Gr_Liv_Area, data = ames)

# Bundle the recipe and the model specification into one workflow.
the_workflow <- add_model(add_recipe(workflow(), the_rec), the_lm_model)

# Resample the workflow on the manually built splits and summarise the
# metrics.
set.seed(123)
the_results <- fit_resamples(the_workflow, splits)
collect_metrics(the_results)
> collect_metrics(the_results)
# A tibble: 2 x 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 rmse standard 56365. 10 1782. Preprocessor1_Model1
2 rsq standard 0.507 10 0.0220 Preprocessor1_Model1
# Compare the resampled RMSE reported by caret with the one reported by
# tidymodels. The RMSE row is selected by metric name rather than by
# position, so the check does not depend on the row order of the metrics
# table.
metrics <- collect_metrics(the_results)
all.equal(
  cv_model1$results$RMSE,
  metrics$mean[metrics$.metric == "rmse"]
)
TRUE
也许有更直接的方法,但我平时不怎么使用 tidymodels,所以无法确定。
如果您在调用 caret::train
之前没有创建折叠:
# Let caret create the 10 cross-validation folds itself while training.
set.seed(123)
cv_control <- trainControl(number = 10, method = "cv")
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = cv_control
)
你可以使用
cv_model1$control$index
cv_model1$control$indexOut
创建一个rsample
对象
# Pair the `index` and `indexOut` entries kept on the fitted caret object as
# list(analysis = ..., assessment = ...), one pair per resample.
rsplit <- map2(
  cv_model1$control$index,
  cv_model1$control$indexOut,
  function(train_idx, test_idx) {
    list(analysis = train_idx, assessment = test_idx)
  }
)
然后按上述步骤进行。
# Build one split object per index pair on `ames`, then gather them into a
# manual resampling set that reuses the list names as ids.
split_list <- lapply(rsplit, function(idx) make_splits(idx, data = ames))
splits <- manual_rset(split_list, names(split_list))
我正在尝试通过将 caret::train()
代码转换为 tidymodels
工作流程来学习 tidymodels
生态系统。我得到的差异我认为是 caret
与 rsample
中重采样算法的副产品。一位同事写了一个 gist,展示了在随机种子相同的情况下两者对同一数据集产生的差异:https://gist.github.com/bradleyboehmke/7794b79a07afb443da11d930ff84bed7
您可以在下面这个简单模型中看到细微的差异(我认为两段代码是等价的):
library(caret)
library(tidyverse)
library(tidymodels)
data(ames)
# Reference fit: 10-fold cross-validation of a one-predictor linear model
# using caret.
set.seed(123)
cv_control <- trainControl(method = "cv", number = 10)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = cv_control
)
# Show the resampling summary of the fitted model.
cv_model1
对比
# The same model expressed as a tidymodels workflow, resampled on folds
# created by rsample.
set.seed(123)
folds <- vfold_cv(ames, v = 10)

# Model specification: linear regression fitted with the "lm" engine.
the_lm_model <- set_engine(linear_reg(), "lm")

# Preprocessing: only the model formula, no extra steps.
the_rec <- recipe(Sale_Price ~ Gr_Liv_Area, data = ames)

# Bundle the recipe and the model specification into one workflow.
the_workflow <- add_model(add_recipe(workflow(), the_rec), the_lm_model)

# Fit the workflow on every fold and summarise the resampled metrics.
the_results <- fit_resamples(the_workflow, folds)
collect_metrics(the_results)
在 tidymodels
工作流程(通常使用 rsample::vfold_cv()
创建重采样)中,是否有直接使用 caret
重采样(来自 caret::createFolds()
)的方法?我希望如果能弄清楚这个细节,就可以用新的生态系统复现复杂的旧代码(用于教学)。
编辑。感谢 Julia Silge 的评论。
函数 rsample2caret() 和 caret2rsample()
可用于在格式之间转换重采样对象。
下面的答案对于从任意格式转换为 rsample 很有用。
旧答案
这是一种将 caret::createFolds
的输出转换为 rsample 的方法
library(caret)
library(tidyverse)
library(tidymodels)
data(ames)
# Create the training-row indices for 10 folds.
set.seed(123)
folds_train <- caret::createFolds(
  y = ames$Sale_Price,
  k = 10,
  returnTrain = TRUE
)

# The held-out rows of a fold are all row positions that are not in its
# training set.
all_rows <- seq_along(ames$Sale_Price)
folds_test <- lapply(
  folds_train,
  function(train_idx) setdiff(all_rows, train_idx)
)
将训练索引和测试索引组合成包含 analysis 与 assessment 两个元素的列表,如 manual_rset 的文档中所述
# Pair each fold's training and held-out indices as
# list(analysis = ..., assessment = ...), the form passed to make_splits().
rsplit <- map2(
  folds_train,
  folds_test,
  function(train_idx, test_idx) {
    list(analysis = train_idx, assessment = test_idx)
  }
)

# Build one split object per fold on `ames`, then gather them into a manual
# resampling set that reuses the fold names as ids.
split_list <- lapply(rsplit, function(idx) make_splits(idx, data = ames))
splits <- manual_rset(split_list, names(split_list))
> splits
# Manual resampling
# A tibble: 10 x 2
splits id
<named list> <chr>
1 <split [2637/293]> Fold01
2 <split [2638/292]> Fold02
3 <split [2637/293]> Fold03
4 <split [2637/293]> Fold04
5 <split [2638/292]> Fold05
6 <split [2637/293]> Fold06
7 <split [2637/293]> Fold07
8 <split [2636/294]> Fold08
9 <split [2636/294]> Fold09
10 <split [2637/293]> Fold10
检查结果是否相同:
# Refit the caret model on the pre-computed training indices.
# NOTE(review): `method` is not set in trainControl(); presumably that is why
# the printed summary labels the resampling "Bootstrapped" even though the
# supplied indices are the 10 cross-validation training sets -- confirm.
set.seed(123)
index_control <- trainControl(index = folds_train)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = index_control
)
> cv_model1
Linear Regression
2930 samples
1 predictor
No pre-processing
Resampling: Bootstrapped (10 reps)
Summary of sample sizes: 2637, 2638, 2637, 2637, 2638, 2637, ...
Resampling results:
RMSE Rsquared MAE
56364.67 0.5066935 38575.21
Tuning parameter 'intercept' was held constant at a value of TRUE
# Model specification: linear regression fitted with the "lm" engine.
the_lm_model <- set_engine(linear_reg(), "lm")

# Preprocessing: only the model formula, no extra steps.
the_rec <- recipe(Sale_Price ~ Gr_Liv_Area, data = ames)

# Bundle the recipe and the model specification into one workflow.
the_workflow <- add_model(add_recipe(workflow(), the_rec), the_lm_model)

# Resample the workflow on the manually built splits and summarise the
# metrics.
set.seed(123)
the_results <- fit_resamples(the_workflow, splits)
collect_metrics(the_results)
> collect_metrics(the_results)
# A tibble: 2 x 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 rmse standard 56365. 10 1782. Preprocessor1_Model1
2 rsq standard 0.507 10 0.0220 Preprocessor1_Model1
# Compare the resampled RMSE reported by caret with the one reported by
# tidymodels. The RMSE row is selected by metric name rather than by
# position, so the check does not depend on the row order of the metrics
# table.
metrics <- collect_metrics(the_results)
all.equal(
  cv_model1$results$RMSE,
  metrics$mean[metrics$.metric == "rmse"]
)
TRUE
也许有更直接的方法,但我平时不怎么使用 tidymodels,所以无法确定。
如果您在调用 caret::train
之前没有创建折叠:
# Let caret create the 10 cross-validation folds itself while training.
set.seed(123)
cv_control <- trainControl(number = 10, method = "cv")
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = cv_control
)
你可以使用
cv_model1$control$index
cv_model1$control$indexOut
创建一个rsample
对象
# Pair the `index` and `indexOut` entries kept on the fitted caret object as
# list(analysis = ..., assessment = ...), one pair per resample.
rsplit <- map2(
  cv_model1$control$index,
  cv_model1$control$indexOut,
  function(train_idx, test_idx) {
    list(analysis = train_idx, assessment = test_idx)
  }
)
然后按上述步骤进行。
# Build one split object per index pair on `ames`, then gather them into a
# manual resampling set that reuses the list names as ids.
split_list <- lapply(rsplit, function(idx) make_splits(idx, data = ames))
splits <- manual_rset(split_list, names(split_list))