R插入符createFolds与createMultiFolds差异

R caret createFolds vs. createMultiFolds discrepancies

我使用 caret.

创建交叉验证折叠

我发现函数 createFolds 和 createMultiFolds 之间存在差异。在我看来,createFolds 没有替换,根据我的理解,这是正确的版本。 createMultiFolds 有两个缺陷,首先它使用替换,其次它每次折叠的观察结果比预期的多得多。

有谁知道为什么会出现这些差异,还是我必须以不同的方式指定它?最后我想用一个重复的交叉验证。

这是一个 MWE:

library(caret)
data(mtcars)

set.seed(123)
folds <- createMultiFolds(y = mtcars$am, k = 5, times = 5)

set.seed(123)
folds <- createFolds(mtcars$am, k = 5)

输出结果如下:

createMultiFolds(仅前 5 次折叠):

Fold1.Rep1  1  2  3  4  6  7  8  9 10  11  12  13  14  15  16  18  20  22  23  24  25  26  27  29  30  31
Fold2.Rep1  1  2  3  5  6  7  8  9 11  12  14  16  17  18  19  20  21  22  23  24  25  28  29  31  32
Fold3.Rep1  2  4  5  6  7  8  9 10 11  12  13  15  17  18  19  20  21  23  26  27  28  29  30  31  32
Fold4.Rep1  1  2  3  4  5  6  7 10 13  14  15  16  17  18  19  21  22  23  24  25  26  27  28  29  30  32
Fold5.Rep1  1  3  4  5  8  9 10 11 12  13  14  15  16  17  19  20  21  22  24  25  26  27  28  30  31  32

创建折叠:

Fold1  5 17 19 21 28 32
Fold2  4 10 13 15 26 27 30
Fold3  1  3 14 16 22 24 25
Fold4  8  9 11 12 20 31
Fold5  2  6  7 18 23 29

如果您检查 createMultiFolds 的源代码,您会发现它用 returnTrain = TRUE 调用了 createFolds。从文档中,

returnTrain: a logical. When true, the values returned are the sample
          positions corresponding to the data used during training.
          This argument only works in conjunction with ‘list = TRUE’

因此,如果你适当地修改createFolds,一切都很好:

> library(caret)
> data(mtcars)
> set.seed(123)
> multiFolds <- createMultiFolds(y = mtcars$am, k = 5, times = 2)
> set.seed(123)
> folds1 <- createFolds(mtcars$am, k = 5, returnTrain = TRUE)
> folds2 <- createFolds(mtcars$am, k = 5, returnTrain = TRUE)
> all(multiFolds$Fold1.Rep1 == folds1$Fold1)
[1] TRUE
> all(multiFolds$Fold2.Rep1 == folds1$Fold2)
[1] TRUE
> all(multiFolds$Fold3.Rep1 == folds1$Fold3)
[1] TRUE
> all(multiFolds$Fold4.Rep1 == folds1$Fold4)
[1] TRUE
> all(multiFolds$Fold5.Rep1 == folds1$Fold5)
[1] TRUE
> all(multiFolds$Fold1.Rep2 == folds2$Fold1)
[1] TRUE
> all(multiFolds$Fold2.Rep2 == folds2$Fold2)
[1] TRUE
> all(multiFolds$Fold3.Rep2 == folds2$Fold3)
[1] TRUE
> all(multiFolds$Fold4.Rep2 == folds2$Fold4)
[1] TRUE
> all(multiFolds$Fold5.Rep2 == folds2$Fold5)
[1] TRUE

createMultiFolds has two flaws, first it uses replacement [...]

你从哪里得到这个的?如果您谈论的是 1,则第一个是名称的一部分:Fold1.Rep1Fold2.Rep1、...、Fold{k}.Rep{times}.

如问题中所述,createFolds() 将数据拆分为 k 倍。但是,该函数的输出是一个观察索引列表,每个折叠 保留 ,而不是每个折叠中包含的行。我们可以通过创建所有折叠数据的 table 来看到这一点,如下所示。

set.seed(123)
folds <- createFolds(mtcars$am, k = 5)
table(unlist(folds))

...以及输出:

> table(unlist(folds))

 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 
 1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1 
27 28 29 30 31 32 
 1  1  1  1  1  1 

如果我们将 returnTrain = TRUE 参数与 createFolds() 一起使用,它 return 是每个折叠中 包含 的观察索引,如图所示在另一个答案中。对于 k = 5,我们希望每个观察值用于 4 次折叠,并使用以下代码确认这一点。

set.seed(123)
folds <- createFolds(mtcars$am, k = 5, returnTrain = TRUE)
table(unlist(folds))

...以及输出:

> table(unlist(folds))

 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 
 4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4 
27 28 29 30 31 32 
 4  4  4  4  4  4 

如答案中所述,设置 returnTrain = TRUE 会导致 createFolds() 到 return 与 createMultiFolds()times = 1 相同的输出。我们可以说明每个观察值用于 5 折中的 4 折,如下所示。

set.seed(123)
folds1 <- createMultiFolds(y = mtcars$am, k = 5, times = 1)
table(unlist(folds1))

...以及输出:

> table(unlist(folds1))

 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 
 4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4 
27 28 29 30 31 32 
 4  4  4  4  4  4 

我们可以将foldsfolds的内容与lapply()all()的内容进行比较,如下所示。

# compare folds to folds1
lapply(1:5,function(x){
     all(folds1[[x]],folds[[x]])
})

[[1]]
[1] TRUE

[[2]]
[1] TRUE

[[3]]
[1] TRUE

[[4]]
[1] TRUE

[[5]]
[1] TRUE

如果我们设置 times = 2,我们希望每个观察值包含在 10 次中的 8 次中。

set.seed(123)
folds <- createMultiFolds(y = mtcars$am, k = 5, times = 2)
table(unlist(folds))

...以及输出:

> table(unlist(folds))

 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 
 8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8 
27 28 29 30 31 32 
 8  8  8  8  8  8

结论: 在两个函数中 caret 使用抽样来确保每个观察值在 k 折叠中包含在保持组中 1 次times = 的每次重复,在传递给函数的因变量的每个值的观察值按比例分布在每个折叠的样本内和样本外分量的约束内。

mtcars这样的小数据集的情况下,算法要进行有效的分割并不容易,我们运行 tables比较的时候可以看到在样本/坚持与 mtcars$am.

set.seed(123)
folds <- createFolds(mtcars$am, k = 5)
table(unlist(folds))
lapply(folds,function(x){
     holdout <- rep(FALSE,nrow(mtcars))
     holdout[x] <- TRUE
     table(holdout,mtcars$am)
})

$Fold1
       
holdout  0  1
  FALSE 16 10
  TRUE   3  3

$Fold2
       
holdout  0  1
  FALSE 15 10
  TRUE   4  3

$Fold3
       
holdout  0  1
  FALSE 14 11
  TRUE   5  2

$Fold4
       
holdout  0  1
  FALSE 15 11
  TRUE   4  2

$Fold5
       
holdout  0  1
  FALSE 16 10
  TRUE   3  3

每个折叠在保留集中包含 6 或 7 个观察值,每个保留集中至少有 2 辆手动变速器汽车 (am = 1)。

使用默认参数,createFolds() returns 是保留观察的索引,而不是包含观察的索引。 createFolds(x,k,returnTrain=TRUE) 的行为与 createMultiFolds(x,k,times=1) 完全相同。