如何在 r 中的 rpart() 中关闭 k 折交叉验证

How to turn off k fold cross validation in rpart() in r

我有比特币时间序列,我使用 11 个技术指标作为特征,我想用回归树拟合数据。据我所知,r 中有 2 个函数可以创建回归树,即 rpart() 和 tree(),但这两个函数似乎都不合适。 rpart() 使用 k 折交叉验证来验证最优成本复杂度参数 cp,而在 tree() 中,无法指定 cp 的值。

我知道 cv.tree() 通过交叉验证寻找 cp 的最佳值,但同样,cv.tee() 使用 k 折交叉验证。由于我有时间序列,因此存在时间依赖性,所以我不想使用 k 折交叉验证,因为 k 折交叉验证会将数据随机分成 k 折,将模型拟合到 k-1 折并计算MSE 左边第 k 次折叠,然后我的时间序列序列显然被破坏了。

我找到了 rpart() 函数的一个参数,即 xval,它应该让我指定交叉验证的数量,但是当我查看 xval=0 时 rpart() 函数调用的输出时,交叉验证似乎没有关闭。下面你可以看到我的函数调用和输出:

tree.model= rpart(Close_5~ M+ DSMA+ DWMA+ DEMA+ CCI+ RSI+ DKD+ R+ FI+ DVI+ 
OBV, data= train.subset, method= "anova", control= 
rpart.control(cp=0.01,xval= 0, minbucket = 5))

> summary(tree.model)
Call:
rpart(formula = Close_5 ~ M + DSMA + DWMA + DEMA + CCI + RSI + 
DKD + R + FI + DVI + OBV, data = train.subset, method = "anova", 
control = rpart.control(cp = 0.01, xval = 0, minbucket = 5))
n= 590 

           CP nsplit rel error
1  0.35433076      0 1.0000000
2  0.10981049      1 0.6456692
3  0.06070669      2 0.5358587
4  0.04154720      3 0.4751521
5  0.02415633      5 0.3920576
6  0.02265346      6 0.3679013
7  0.02139752      8 0.3225944
8  0.02096500      9 0.3011969
9  0.02086543     10 0.2802319
10 0.01675277     11 0.2593665
11 0.01551861     13 0.2258609
12 0.01388126     14 0.2103423
13 0.01161287     15 0.1964610
14 0.01127722     16 0.1848482
15 0.01000000     18 0.1622937

似乎 rpart() 交叉验证了 15 个不同的 cp 值。如果这些值是用 k 折交叉验证测试的,那么我的时间序列的序列又会被破坏,我基本上不能使用这些结果。有谁知道如何有效地关闭 rpart() 中的交叉验证,或者如何改变 tree() 中的 cp 值?

更新:我听从了一位同事的建议并设置了 xval=1,但这似乎并没有解决问题。当 xval=1 here 时,您可以看到完整的函数输出。顺便说一句,parameters[j] 是参数向量的第 j 个元素。当我调用这个函数时,parameters[j]= 0.0009765625

非常感谢

为了证明 rpart() 是通过迭代 cp 的下降值而不是重采样来创建树节点,我们将使用 mlbench 包中的 Ozone 数据比较 OP 评论中讨论的 rpart()caret::train() 的结果。我们将按照 Support Vector Machines 的 CRAN 文档中的说明设置臭氧数据,它支持非线性回归并且与 rpart() 相当。

library(rpart)
library(caret)
data(Ozone, package = "mlbench")
# split into test and training
index <- 1:nrow(Ozone)
set.seed(01381708)
testIndex <- sample(index, trunc(length(index) / 3))
testset <- na.omit(Ozone[testIndex,-3])
trainset <- na.omit(Ozone[-testIndex,-3])


# rpart version
set.seed(95014) #reset seed to ensure sample is same as caret version
rpart.model <- rpart(V4 ~ .,data = trainset,xval=0)
# summary(rpart.model)
# calculate RMSE
rpart.pred <- predict(rpart.model, testset[,-3])
crossprod(rpart.pred - testset[,3]) / length(testIndex)

...以及 RMSE 计算的输出:

> crossprod(rpart.pred - testset[,3]) / length(testIndex)
         [,1]
[1,] 18.25507

接下来,我们将 运行 与 caret::train() 进行相同的分析,如对 OP 的评论中所建议的那样。

# caret version
set.seed(95014)
rpart.model <- caret::train(x = trainset[,-3],
                            y = trainset[,3],method = "rpart", trControl = trainControl(method = "none"), 
                            metric = "RMSE", tuneGrid = data.frame(cp=0.01), 
                            preProcess = c("center", "scale"), xval = 0, minbucket = 5)
# summary(rpart.model)
# demonstrate caret version did not do resampling
rpart.model
# calculate RMSE, which matches RMSE from rpart() 
rpart.pred <- predict(rpart.model, testset[,-3])
crossprod(rpart.pred - testset[,3]) / length(testIndex)

当我们打印 caret::train() 的模型输出时,它清楚地指出没有重新采样。

> rpart.model
CART 

135 samples
 11 predictor

Pre-processing: centered (9), scaled (9), ignore (2) 
Resampling: None

caret::train() 版本的 RMSE 与 rpart() 的 RMSE 匹配。

> # calculate RMSE, which matches RMSE from rpart() 
> rpart.pred <- predict(rpart.model, testset[,-3])
> crossprod(rpart.pred - testset[,3]) / length(testIndex)
         [,1]
[1,] 18.25507
> 

结论

首先,如上配置,caret::train()rpart()都没有重采样。但是,如果打印模型输出,则会看到 cp 的多个值用于通过这两种技术生成最终的 47 个节点的树。

插入符号的输出 summary(rpart.model)

          CP nsplit rel error
1 0.58951537      0 1.0000000
2 0.08544094      1 0.4104846
3 0.05237152      2 0.3250437
4 0.04686890      3 0.2726722
5 0.03603843      4 0.2258033
6 0.02651451      5 0.1897648
7 0.02194866      6 0.1632503
8 0.01000000      7 0.1413017

rpart 的输出 summary(rpart.model)

          CP nsplit rel error
1 0.58951537      0 1.0000000
2 0.08544094      1 0.4104846
3 0.05237152      2 0.3250437
4 0.04686890      3 0.2726722
5 0.03603843      4 0.2258033
6 0.02651451      5 0.1897648
7 0.02194866      6 0.1632503
8 0.01000000      7 0.1413017

其次,两个模型都通过将 monthday 变量作为独立变量来考虑时间值。 Ozone数据集中,V1为月变量,V2为日变量。所有数据都是在 1976 年收集的,因此数据集中没有年份变量,并且在 svm 小插图的原始分析中,星期几在分析之前被删除。

第三,当日期属性未用作模型中的特征时,要使用 rpart()svm() 等算法考虑其他基于时间的影响,必须将滞后效应作为特征包含在模型中模型,因为这些算法不直接考虑时间分量。 Ensemble Regression Trees for Time Series Predictions 是如何使用一系列滞后值对一组回归树执行此操作的示例。

在您的模型中,只需 xval=0 关闭交叉验证即可。

在你的输出中,你只有 CP NSPLIT REL ERROR,通过交叉验证你应该有 CP NSPLIT REL ERROR XERROR XSTD.

cp 只是您的“复杂度参数”(默认情况下 cp=0.01)从 1 到 0.01。

rel error 是您对数据集训练的预测误差/根节点的预期损失。

根据 cp

nsplit 与树大小相关的节点数。

看:https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf