使用 R 中 'rpart' 包中的生存树来预测新的观察结果

Using a survival tree from the 'rpart' package in R to predict new observations

我正在尝试使用 R 中的 "rpart" 包构建生存树,我希望使用这棵树对其他观察结果进行预测。

我知道有很多涉及 rpart 和预测的 SO 问题;但是,我还没有找到任何解决(我认为)特定于将 rpart 与 "Surv" 对象一起使用的问题。

我的特殊问题涉及解释 "predict" 函数的结果。一个有用的例子:

library(rpart)
library(OIsurv)

# Make Data:
set.seed(4)
dat = data.frame(X1 = sample(x = c(1,2,3,4,5), size = 1000, replace=T))
dat$t = rexp(1000, rate=dat$X1)
dat$t = dat$t / max(dat$t)
dat$e = rbinom(n = 1000, size = 1, prob = 1-dat$t )

# Survival Fit:
sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
plot(sfit)

# Tree Fit:
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)

# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )

到目前为止一切顺利。我对这里发生的事情的理解是 rpart 试图将指数生存曲线拟合到我的数据子集。基于这种理解,我相信当我调用 predict(tfit) 时,对于每个观察,我都会得到一个与该观察的指数曲线参数相对应的数字。因此,例如,如果 predict(fit)[1] 是 .46,那么这意味着对于我的原始数据集中的第一次观察,曲线由方程 P(s) = exp(−λt) 给出,其中 λ=.46.

这似乎正是我想要的。对于每个观察(或任何新观察),我可以获得该观察在给定时间点为 alive/dead 的预测概率。 (编辑:我意识到这可能是一个误解——这些曲线给出的不是 alive/dead 的概率,而是一个区间存活的概率。这不会改变下面描述的问题, 不过。)

但是,当我尝试使用指数公式时...

# Predict:
# an attempt to use the rates extracted from the tree to
# capture the survival curve formula in each tree node.
rates = unique(predict(tfit))
for (rate in rates) {
  grid= seq(0,1,length.out = 100)
  lines(x= grid, y= exp(-rate*(grid)), col=2)
}

我在这里所做的是以与生存树相同的方式拆分数据集,然后使用 survfit 为每个分区绘制非参数曲线。那就是黑线。我还绘制了对应于将 'rate' 参数插入(我认为的)生存指数公式的结果的线。

我知道非参数和参数拟合不一定相同,但这似乎不止于此:我似乎需要缩放 X 变量或其他东西。

基本上,我似乎不理解 rpart/survival 背后使用的公式。谁能帮我从 (1) rpart 模型到 (2) 任意观察的生存方程?

生存数据在内部按指数缩放,因此根节点中的预测率始终固定为 1.000predict() 方法报告的预测总是相对于根节点中的存活率,即高出或低于某个因素。有关详细信息,请参阅 vignette("longintro", package = "rpart") 中的第 8.4 节。无论如何,您报告的 Kaplan-Meier 曲线与 rpart 小插图中报告的完全一致。

如果您想直接获取树中的 Kaplan-Meier 曲线图并获得预测的中位生存时间,您可以将 rpart 树强制转换为 constparty 树,如partykit 包:

library("partykit")
(tfit2 <- as.party(tfit))
## Model formula:
## Surv(t, event = e) ~ X1
## 
## Fitted party:
## [1] root
## |   [2] X1 < 2.5
## |   |   [3] X1 < 1.5: 0.192 (n = 213)
## |   |   [4] X1 >= 1.5: 0.082 (n = 213)
## |   [5] X1 >= 2.5: 0.037 (n = 574)
## 
## Number of inner nodes:    2
## Number of terminal nodes: 3
##
plot(tfit2)

打印输出显示了中位生存时间和相应的 Kaplan-Meier 曲线的可视化。两者也可以通过 predict() 方法将 type 参数分别设置为 "response""prob" 来获得。

predict(tfit2, type = "response")[1]
##          5 
## 0.03671885 
predict(tfit2, type = "prob")[[1]]
## Call: survfit(formula = y ~ 1, weights = w, subset = w > 0)
## 
##  records    n.max  n.start   events   median  0.95LCL  0.95UCL 
## 574.0000 574.0000 574.0000 542.0000   0.0367   0.0323   0.0408 

作为 rpart 生存树的替代方案,您还可以考虑基于 ctree() 中的条件推理(使用对数秩分数)的非参数生存树或使用一般mob() 来自 partykit 包的基础设施。

@Achim Zeileis 的回答很有帮助,但似乎没有回答@jwdink 的确切问题。我将其理解为“如果 RPart 树按最佳指数生存拟合进行拆分,那么这些拟合的 Lambdas 的绝对值是多少,因此我们可以使用这些指数生存函数进行预测”。 RPart 摘要确实显示了估计速率,但仅在假设整个总体的速率为 1 的情况下以相对术语显示。为了克服这一问题,可以拟合指数 survreg,从那里获取引用的 lambda,然后将 RPart 预测速率乘以该数字(请参阅下面的代码)。

也就是说,这不是如何从树中预测 RPart 中的存活率。我没有直接在 RPart 中找到生存预测函数,但是正如 Achim 上面指出的那样,partykit 使用 Kaplan-Meier 估计,即来自最终叶子的非参数生存。我认为在生存随机森林树中也是一样的,在最后的叶子中使用 K-M 曲线。

这个问题中的模拟数据使用指数分布,因此 K-M 和指数生存曲线在设计上是相似的,但是对于不同的模拟或现实生活分布,通过 RPart 树估计指数率并在最终中使用 K-M 曲线(同一棵树的)叶子会给出不同的存活率。

sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)

# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
table(dat$node)
s0 = survreg(Surv(t,e)~ 1, data =  dat, dist = "exponential") #-0.6175
e0 = exp(-summary(s0)$coefficients[1]); e0 #1.854
rates = unique(predict(tfit))
#1) plot K-M curves by node (black):
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )

#2) plot exponential survival with rates = e0 * RPart rates (red):
for (rate in rates) {
  grid= seq(0,1,length.out = 100)
  lines(x= grid, y= exp(-e0*rate*(grid)), col=2)
}
#3) plot partykit survival curves based on RPart tree (green)
library(partykit)
tfit2 <- as.party(tfit)
col_n = 1
for (node in names(table(dat$node))){
  predict_curve = predict(tfit2, newdata = dat[dat$node == node, ], type = "prob")  
  surv_esitmated = approxfun(predict_curve[[1]]$time, predict_curve[[1]]$surv)
  lines(x= grid, y= surv_esitmated(grid), col = 2+col_n)
  col_n=+1
}