Rpart vs. caret rpart "Error: There were missing values in resampled performance measures"

Rpart vs. caret rpart "Error: There were missing values in resampled performance measures"

我使用了 caret 包并尝试使用 rpart 方法。有趣的是,我可以用通用的 rpart 包拟合一个模型,但是一旦我使用 caret 包,它就不再起作用了。更让我困惑的是,我在各种网站上看到使用了 caret 中的 rpart,例如对于波士顿数据。

我很困惑是我的模型实现不正确还是我在这里漏掉了一点。 对于 rpart_tree2(下方),我收到以下错误消息:“在 nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo, : There were missing values in resampled performance measures.”

我知道我也可以指定例如repeatedcv,但这对错误消息没有影响。

下面是一个MWE(我尽量保持简单):

library(caret)
library(rpart)

data("Boston")

index <- sample(nrow(Boston),nrow(Boston)*0.75)
Boston.train <- Boston[index,]
Boston.test <- Boston[-index,]

rpart_tree1 <- rpart(medv ~ ., data = Boston.train)

rpart_tree2 <- train(medv ~., data = Boston.train, method = "rpart")

警告不是问题。

在某些重采样中具有较大的 cp 值,生成的树没有分裂。当树没有分裂时,预测值是训练结果值的平均值。由于预测值没有方差,cor 函数会引发警告,结果为 NA。此函数用于计算 RSquared - 因此对于这些重采样 RSquared 是 NA - 换句话说它丢失了 - 警告暗示。

示例:

library(caret)
library(rpart)
library(MASS)
data(Boston)

set.seed(1)
index <- sample(nrow(Boston),nrow(Boston)*0.75)
Boston.train <- Boston[index,]
Boston.test <- Boston[-index,]

较低 cp 不产生警告:

rpart_tree2 <- train(medv ~., data = Boston.train, method = "rpart",
                     tuneGrid = data.frame(cp = c(0.01, 0.05, 0.1)))

当我指定更高的 cp 和特定的种子时:

set.seed(111)
rpart_tree3 <- train(medv ~., data = Boston.train, method = "rpart",
                     tuneGrid = data.frame(cp = c(0.4)),
                     trControl = trainControl(savePredictions = TRUE))

Warning message:
In nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,  :
  There were missing values in resampled performance measures.

检查问题:

rpart_tree3$resample
        RMSE  Rsquared      MAE   Resample
1   7.530482 0.4361392 5.708437 Resample01
2   7.334995 0.2350619 5.392867 Resample02
3   7.178178 0.3971089 5.511530 Resample03
4   6.369189 0.2798907 4.851146 Resample04
5   7.550175 0.3344412 5.566677 Resample05
6   7.019099 0.4270561 5.160572 Resample06
7   7.197384 0.4530680 5.665177 Resample07
8   7.206760 0.3447690 5.290300 Resample08
9   7.408748 0.4553087 5.513998 Resample09
10  7.241468 0.4119979 5.452725 Resample10
11  7.562511 0.3967082 5.768643 Resample11
12  7.347378 0.3861702 5.225532 Resample12
13  7.124039 0.4039857 5.599800 Resample13
14  7.151013 0.3301835 5.490676 Resample14
15  6.518536 0.3835073 4.938662 Resample15
16 10.008008        NA 7.174290 Resample16
17  7.018742 0.4431380 5.379823 Resample17
18  7.454669 0.3888220 6.000062 Resample18
19  6.745457 0.3772237 5.175481 Resample19
20  6.864304 0.4179276 5.089924 Resample20
21  7.238874 0.2378432 5.234752 Resample21
22  7.581736 0.3707839 5.543641 Resample22
23  7.236317 0.3431725 5.278693 Resample23
24  7.232241 0.4196955 5.518907 Resample24
25  6.641846 0.3664023 4.683834 Resample25

我们可以看到问题发生在Resample16

library(tidyverse)
rpart_tree3$pred %>%
  filter(Resample == "Resample16") -> for_cor
head(for_cor)
      pred  obs rowIndex  cp   Resample
1 21.87018 15.6        1 0.4 Resample16
2 21.87018 22.3        3 0.4 Resample16
3 21.87018 13.4        6 0.4 Resample16
4 21.87018 12.7       10 0.4 Resample16
5 21.87018 18.6       11 0.4 Resample16
6 21.87018 19.0       13 0.4 Resample16

我们可以看到 pred 对于 Resample16

的每一行都是相同的
 cor(for_cor$pred, for_cor$obs, use = "pairwise.complete.obs")
[1] NA
Warning message:
In cor(for_cor$pred, for_cor$obs, use = "pairwise.complete.obs") :
  the standard deviation is zero

要查看插入符号中的 RSquared 是如何计算的,请查看 postResample 的来源。基本上 cor(pred, obs)^2