在 train 方法中,tuneGrid 和 trControl 之间的关系是什么?

In the train method what's the relationship between tuneGrid and trControl?

R 中训练已知 ML 模型的首选方法是使用 caret 包及其通用 train 方法。我的问题是 tuneGridtrControl 参数之间的关系是什么?因为它们无疑是相关的,我无法通过阅读文档来弄清楚它们的关系......例如:

library(caret)  
# train and choose best model using cross validation
df <- ... # contains input data
control <- trainControl(method = "cv", number = 10, p = .9, allowParallel = TRUE)
fit <- train(y ~ ., method = "knn", 
             data = df,
             tuneGrid = data.frame(k = seq(9, 71, 2)),
             trControl = control)

如果我运行上面的代码发生了什么?根据 trainControl 定义,每个包含 90% 数据的 10 个 CV 折叠如何与 k 的 32 个级别相结合?

更具体地说:

k近邻模型是否训练了32*10次?或者其他?

是的,你是对的。您将训练数据分成 10 组,比如 1..10。从第 1 组开始,您使用全部 2..10(90% 的训练数据)训练您的模型,并在第 1 组上测试它。这对 set2、set3 再次重复。总共 10 次,您有 32 个 k 值要测试,因此 32 * 10 = 320.

您也可以使用 trainControl 中的 returnResamp 函数提取此 cv 结果。我把它简化为下面的3倍和4个k值:

df <- mtcars
set.seed(100)
control <- trainControl(method = "cv", number = 3, p = .9,returnResamp="all")
fit <- train(mpg  ~ ., method = "knn", 
             data = mtcars,
             tuneGrid = data.frame(k = 2:5),
             trControl = control)

resample_results = fit$resample
resample_results
       RMSE  Rsquared      MAE k Resample
1  3.502321 0.7772086 2.483333 2    Fold1
2  3.807011 0.7636239 2.861111 3    Fold1
3  3.592665 0.8035741 2.697917 4    Fold1
4  3.682105 0.8486331 2.741667 5    Fold1
5  2.473611 0.8665093 1.995000 2    Fold2
6  2.673429 0.8128622 2.210000 3    Fold2
7  2.983224 0.7120910 2.645000 4    Fold2
8  2.998199 0.7207914 2.608000 5    Fold2
9  2.094039 0.9620830 1.610000 2    Fold3
10 2.551035 0.8717981 2.113333 3    Fold3
11 2.893192 0.8324555 2.482500 4    Fold3
12 2.806870 0.8700533 2.368333 5    Fold3

# we manually calculate the mean RMSE for each parameter
tapply(resample_results$RMSE,resample_results$k,mean)
       2        3        4        5 
2.689990 3.010492 3.156360 3.162392

# and we can see it corresponds to the final fit result
fit$results
k     RMSE  Rsquared      MAE    RMSESD RsquaredSD     MAESD
1 2 2.689990 0.8686003 2.029444 0.7286489 0.09245494 0.4376844
2 3 3.010492 0.8160947 2.394815 0.6925154 0.05415954 0.4067066
3 4 3.156360 0.7827069 2.608472 0.3805227 0.06283697 0.1122577
4 5 3.162392 0.8131593 2.572667 0.4601396 0.08070670 0.1891581