在 train 方法中,tuneGrid 和 trControl 之间的关系是什么?
In the train method what's the relationship between tuneGrid and trControl?
R 中训练已知 ML 模型的首选方法是使用 caret
包及其通用 train
方法。我的问题是 tuneGrid
和 trControl
参数之间的关系是什么?因为它们无疑是相关的,我无法通过阅读文档来弄清楚它们的关系......例如:
library(caret)
# train and choose best model using cross validation
df <- ... # contains input data
control <- trainControl(method = "cv", number = 10, p = .9, allowParallel = TRUE)
fit <- train(y ~ ., method = "knn",
data = df,
tuneGrid = data.frame(k = seq(9, 71, 2)),
trControl = control)
如果我运行上面的代码发生了什么?根据 trainControl
定义,每个包含 90% 数据的 10 个 CV 折叠如何与 k
的 32 个级别相结合?
更具体地说:
- 参数
k
. 我有 32 个级别
- 我也有 10 次 CV 折叠。
k近邻模型是否训练了32*10次?或者其他?
是的,你是对的。您将训练数据分成 10 组,比如 1..10。从第 1 组开始,您使用全部 2..10(90% 的训练数据)训练您的模型,并在第 1 组上测试它。这对 set2、set3 再次重复。总共 10 次,您有 32 个 k 值要测试,因此 32 * 10 = 320.
您也可以使用 trainControl 中的 returnResamp 函数提取此 cv 结果。我把它简化为下面的3倍和4个k值:
df <- mtcars
set.seed(100)
control <- trainControl(method = "cv", number = 3, p = .9,returnResamp="all")
fit <- train(mpg ~ ., method = "knn",
data = mtcars,
tuneGrid = data.frame(k = 2:5),
trControl = control)
resample_results = fit$resample
resample_results
RMSE Rsquared MAE k Resample
1 3.502321 0.7772086 2.483333 2 Fold1
2 3.807011 0.7636239 2.861111 3 Fold1
3 3.592665 0.8035741 2.697917 4 Fold1
4 3.682105 0.8486331 2.741667 5 Fold1
5 2.473611 0.8665093 1.995000 2 Fold2
6 2.673429 0.8128622 2.210000 3 Fold2
7 2.983224 0.7120910 2.645000 4 Fold2
8 2.998199 0.7207914 2.608000 5 Fold2
9 2.094039 0.9620830 1.610000 2 Fold3
10 2.551035 0.8717981 2.113333 3 Fold3
11 2.893192 0.8324555 2.482500 4 Fold3
12 2.806870 0.8700533 2.368333 5 Fold3
# we manually calculate the mean RMSE for each parameter
tapply(resample_results$RMSE,resample_results$k,mean)
2 3 4 5
2.689990 3.010492 3.156360 3.162392
# and we can see it corresponds to the final fit result
fit$results
k RMSE Rsquared MAE RMSESD RsquaredSD MAESD
1 2 2.689990 0.8686003 2.029444 0.7286489 0.09245494 0.4376844
2 3 3.010492 0.8160947 2.394815 0.6925154 0.05415954 0.4067066
3 4 3.156360 0.7827069 2.608472 0.3805227 0.06283697 0.1122577
4 5 3.162392 0.8131593 2.572667 0.4601396 0.08070670 0.1891581
R 中训练已知 ML 模型的首选方法是使用 caret
包及其通用 train
方法。我的问题是 tuneGrid
和 trControl
参数之间的关系是什么?因为它们无疑是相关的,我无法通过阅读文档来弄清楚它们的关系......例如:
library(caret)
# train and choose best model using cross validation
df <- ... # contains input data
control <- trainControl(method = "cv", number = 10, p = .9, allowParallel = TRUE)
fit <- train(y ~ ., method = "knn",
data = df,
tuneGrid = data.frame(k = seq(9, 71, 2)),
trControl = control)
如果我运行上面的代码发生了什么?根据 trainControl
定义,每个包含 90% 数据的 10 个 CV 折叠如何与 k
的 32 个级别相结合?
更具体地说:
- 参数
k
. 我有 32 个级别
- 我也有 10 次 CV 折叠。
k近邻模型是否训练了32*10次?或者其他?
是的,你是对的。您将训练数据分成 10 组,比如 1..10。从第 1 组开始,您使用全部 2..10(90% 的训练数据)训练您的模型,并在第 1 组上测试它。这对 set2、set3 再次重复。总共 10 次,您有 32 个 k 值要测试,因此 32 * 10 = 320.
您也可以使用 trainControl 中的 returnResamp 函数提取此 cv 结果。我把它简化为下面的3倍和4个k值:
df <- mtcars
set.seed(100)
control <- trainControl(method = "cv", number = 3, p = .9,returnResamp="all")
fit <- train(mpg ~ ., method = "knn",
data = mtcars,
tuneGrid = data.frame(k = 2:5),
trControl = control)
resample_results = fit$resample
resample_results
RMSE Rsquared MAE k Resample
1 3.502321 0.7772086 2.483333 2 Fold1
2 3.807011 0.7636239 2.861111 3 Fold1
3 3.592665 0.8035741 2.697917 4 Fold1
4 3.682105 0.8486331 2.741667 5 Fold1
5 2.473611 0.8665093 1.995000 2 Fold2
6 2.673429 0.8128622 2.210000 3 Fold2
7 2.983224 0.7120910 2.645000 4 Fold2
8 2.998199 0.7207914 2.608000 5 Fold2
9 2.094039 0.9620830 1.610000 2 Fold3
10 2.551035 0.8717981 2.113333 3 Fold3
11 2.893192 0.8324555 2.482500 4 Fold3
12 2.806870 0.8700533 2.368333 5 Fold3
# we manually calculate the mean RMSE for each parameter
tapply(resample_results$RMSE,resample_results$k,mean)
2 3 4 5
2.689990 3.010492 3.156360 3.162392
# and we can see it corresponds to the final fit result
fit$results
k RMSE Rsquared MAE RMSESD RsquaredSD MAESD
1 2 2.689990 0.8686003 2.029444 0.7286489 0.09245494 0.4376844
2 3 3.010492 0.8160947 2.394815 0.6925154 0.05415954 0.4067066
3 4 3.156360 0.7827069 2.608472 0.3805227 0.06283697 0.1122577
4 5 3.162392 0.8131593 2.572667 0.4601396 0.08070670 0.1891581