使用 R 中 'rpart' 包中的生存树来预测新的观察结果
Using a survival tree from the 'rpart' package in R to predict new observations
我正在尝试使用 R 中的 "rpart" 包构建生存树,我希望使用这棵树对其他观察结果进行预测。
我知道有很多涉及 rpart 和预测的 SO 问题;但是,我还没有找到任何解决(我认为)特定于将 rpart 与 "Surv" 对象一起使用的问题。
我的特殊问题涉及解释 "predict" 函数的结果。一个有用的例子:
library(rpart)
library(OIsurv)
# Make Data:
set.seed(4)
dat = data.frame(X1 = sample(x = c(1,2,3,4,5), size = 1000, replace=T))
dat$t = rexp(1000, rate=dat$X1)
dat$t = dat$t / max(dat$t)
dat$e = rbinom(n = 1000, size = 1, prob = 1-dat$t )
# Survival Fit:
sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
plot(sfit)
# Tree Fit:
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)
# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )
到目前为止一切顺利。我对这里发生的事情的理解是 rpart 试图将指数生存曲线拟合到我的数据子集。基于这种理解,我相信当我调用 predict(tfit)
时,对于每个观察,我都会得到一个与该观察的指数曲线参数相对应的数字。因此,例如,如果 predict(fit)[1]
是 .46,那么这意味着对于我的原始数据集中的第一次观察,曲线由方程 P(s) = exp(−λt)
给出,其中 λ=.46
.
这似乎正是我想要的。对于每个观察(或任何新观察),我可以获得该观察在给定时间点为 alive/dead 的预测概率。 (编辑:我意识到这可能是一个误解——这些曲线给出的不是 alive/dead 的概率,而是一个区间存活的概率。这不会改变下面描述的问题, 不过。)
但是,当我尝试使用指数公式时...
# Predict:
# an attempt to use the rates extracted from the tree to
# capture the survival curve formula in each tree node.
rates = unique(predict(tfit))
for (rate in rates) {
grid= seq(0,1,length.out = 100)
lines(x= grid, y= exp(-rate*(grid)), col=2)
}
我在这里所做的是以与生存树相同的方式拆分数据集,然后使用 survfit
为每个分区绘制非参数曲线。那就是黑线。我还绘制了对应于将 'rate' 参数插入(我认为的)生存指数公式的结果的线。
我知道非参数和参数拟合不一定相同,但这似乎不止于此:我似乎需要缩放 X 变量或其他东西。
基本上,我似乎不理解 rpart/survival 背后使用的公式。谁能帮我从 (1) rpart 模型到 (2) 任意观察的生存方程?
生存数据在内部按指数缩放,因此根节点中的预测率始终固定为 1.000
。 predict()
方法报告的预测总是相对于根节点中的存活率,即高出或低于某个因素。有关详细信息,请参阅 vignette("longintro", package = "rpart")
中的第 8.4 节。无论如何,您报告的 Kaplan-Meier 曲线与 rpart
小插图中报告的完全一致。
如果您想直接获取树中的 Kaplan-Meier 曲线图并获得预测的中位生存时间,您可以将 rpart
树强制转换为 constparty
树,如partykit
包:
library("partykit")
(tfit2 <- as.party(tfit))
## Model formula:
## Surv(t, event = e) ~ X1
##
## Fitted party:
## [1] root
## | [2] X1 < 2.5
## | | [3] X1 < 1.5: 0.192 (n = 213)
## | | [4] X1 >= 1.5: 0.082 (n = 213)
## | [5] X1 >= 2.5: 0.037 (n = 574)
##
## Number of inner nodes: 2
## Number of terminal nodes: 3
##
plot(tfit2)
打印输出显示了中位生存时间和相应的 Kaplan-Meier 曲线的可视化。两者也可以通过 predict()
方法将 type
参数分别设置为 "response"
和 "prob"
来获得。
predict(tfit2, type = "response")[1]
## 5
## 0.03671885
predict(tfit2, type = "prob")[[1]]
## Call: survfit(formula = y ~ 1, weights = w, subset = w > 0)
##
## records n.max n.start events median 0.95LCL 0.95UCL
## 574.0000 574.0000 574.0000 542.0000 0.0367 0.0323 0.0408
作为 rpart
生存树的替代方案,您还可以考虑基于 ctree()
中的条件推理(使用对数秩分数)的非参数生存树或使用一般mob()
来自 partykit
包的基础设施。
@Achim Zeileis 的回答很有帮助,但似乎没有回答@jwdink 的确切问题。我将其理解为“如果 RPart 树按最佳指数生存拟合进行拆分,那么这些拟合的 Lambdas 的绝对值是多少,因此我们可以使用这些指数生存函数进行预测”。 RPart 摘要确实显示了估计速率,但仅在假设整个总体的速率为 1 的情况下以相对术语显示。为了克服这一问题,可以拟合指数 survreg,从那里获取引用的 lambda,然后将 RPart 预测速率乘以该数字(请参阅下面的代码)。
也就是说,这不是如何从树中预测 RPart 中的存活率。我没有直接在 RPart 中找到生存预测函数,但是正如 Achim 上面指出的那样,partykit 使用 Kaplan-Meier 估计,即来自最终叶子的非参数生存。我认为在生存随机森林树中也是一样的,在最后的叶子中使用 K-M 曲线。
这个问题中的模拟数据使用指数分布,因此 K-M 和指数生存曲线在设计上是相似的,但是对于不同的模拟或现实生活分布,通过 RPart 树估计指数率并在最终中使用 K-M 曲线(同一棵树的)叶子会给出不同的存活率。
sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)
# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
table(dat$node)
s0 = survreg(Surv(t,e)~ 1, data = dat, dist = "exponential") #-0.6175
e0 = exp(-summary(s0)$coefficients[1]); e0 #1.854
rates = unique(predict(tfit))
#1) plot K-M curves by node (black):
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )
#2) plot exponential survival with rates = e0 * RPart rates (red):
for (rate in rates) {
grid= seq(0,1,length.out = 100)
lines(x= grid, y= exp(-e0*rate*(grid)), col=2)
}
#3) plot partykit survival curves based on RPart tree (green)
library(partykit)
tfit2 <- as.party(tfit)
col_n = 1
for (node in names(table(dat$node))){
predict_curve = predict(tfit2, newdata = dat[dat$node == node, ], type = "prob")
surv_esitmated = approxfun(predict_curve[[1]]$time, predict_curve[[1]]$surv)
lines(x= grid, y= surv_esitmated(grid), col = 2+col_n)
col_n=+1
}
我正在尝试使用 R 中的 "rpart" 包构建生存树,我希望使用这棵树对其他观察结果进行预测。
我知道有很多涉及 rpart 和预测的 SO 问题;但是,我还没有找到任何解决(我认为)特定于将 rpart 与 "Surv" 对象一起使用的问题。
我的特殊问题涉及解释 "predict" 函数的结果。一个有用的例子:
library(rpart)
library(OIsurv)
# Make Data:
set.seed(4)
dat = data.frame(X1 = sample(x = c(1,2,3,4,5), size = 1000, replace=T))
dat$t = rexp(1000, rate=dat$X1)
dat$t = dat$t / max(dat$t)
dat$e = rbinom(n = 1000, size = 1, prob = 1-dat$t )
# Survival Fit:
sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
plot(sfit)
# Tree Fit:
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)
# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )
到目前为止一切顺利。我对这里发生的事情的理解是 rpart 试图将指数生存曲线拟合到我的数据子集。基于这种理解,我相信当我调用 predict(tfit)
时,对于每个观察,我都会得到一个与该观察的指数曲线参数相对应的数字。因此,例如,如果 predict(fit)[1]
是 .46,那么这意味着对于我的原始数据集中的第一次观察,曲线由方程 P(s) = exp(−λt)
给出,其中 λ=.46
.
这似乎正是我想要的。对于每个观察(或任何新观察),我可以获得该观察在给定时间点为 alive/dead 的预测概率。 (编辑:我意识到这可能是一个误解——这些曲线给出的不是 alive/dead 的概率,而是一个区间存活的概率。这不会改变下面描述的问题, 不过。)
但是,当我尝试使用指数公式时...
# Predict:
# an attempt to use the rates extracted from the tree to
# capture the survival curve formula in each tree node.
rates = unique(predict(tfit))
for (rate in rates) {
grid= seq(0,1,length.out = 100)
lines(x= grid, y= exp(-rate*(grid)), col=2)
}
我在这里所做的是以与生存树相同的方式拆分数据集,然后使用 survfit
为每个分区绘制非参数曲线。那就是黑线。我还绘制了对应于将 'rate' 参数插入(我认为的)生存指数公式的结果的线。
我知道非参数和参数拟合不一定相同,但这似乎不止于此:我似乎需要缩放 X 变量或其他东西。
基本上,我似乎不理解 rpart/survival 背后使用的公式。谁能帮我从 (1) rpart 模型到 (2) 任意观察的生存方程?
生存数据在内部按指数缩放,因此根节点中的预测率始终固定为 1.000
。 predict()
方法报告的预测总是相对于根节点中的存活率,即高出或低于某个因素。有关详细信息,请参阅 vignette("longintro", package = "rpart")
中的第 8.4 节。无论如何,您报告的 Kaplan-Meier 曲线与 rpart
小插图中报告的完全一致。
如果您想直接获取树中的 Kaplan-Meier 曲线图并获得预测的中位生存时间,您可以将 rpart
树强制转换为 constparty
树,如partykit
包:
library("partykit")
(tfit2 <- as.party(tfit))
## Model formula:
## Surv(t, event = e) ~ X1
##
## Fitted party:
## [1] root
## | [2] X1 < 2.5
## | | [3] X1 < 1.5: 0.192 (n = 213)
## | | [4] X1 >= 1.5: 0.082 (n = 213)
## | [5] X1 >= 2.5: 0.037 (n = 574)
##
## Number of inner nodes: 2
## Number of terminal nodes: 3
##
plot(tfit2)
打印输出显示了中位生存时间和相应的 Kaplan-Meier 曲线的可视化。两者也可以通过 predict()
方法将 type
参数分别设置为 "response"
和 "prob"
来获得。
predict(tfit2, type = "response")[1]
## 5
## 0.03671885
predict(tfit2, type = "prob")[[1]]
## Call: survfit(formula = y ~ 1, weights = w, subset = w > 0)
##
## records n.max n.start events median 0.95LCL 0.95UCL
## 574.0000 574.0000 574.0000 542.0000 0.0367 0.0323 0.0408
作为 rpart
生存树的替代方案,您还可以考虑基于 ctree()
中的条件推理(使用对数秩分数)的非参数生存树或使用一般mob()
来自 partykit
包的基础设施。
@Achim Zeileis 的回答很有帮助,但似乎没有回答@jwdink 的确切问题。我将其理解为“如果 RPart 树按最佳指数生存拟合进行拆分,那么这些拟合的 Lambdas 的绝对值是多少,因此我们可以使用这些指数生存函数进行预测”。 RPart 摘要确实显示了估计速率,但仅在假设整个总体的速率为 1 的情况下以相对术语显示。为了克服这一问题,可以拟合指数 survreg,从那里获取引用的 lambda,然后将 RPart 预测速率乘以该数字(请参阅下面的代码)。
也就是说,这不是如何从树中预测 RPart 中的存活率。我没有直接在 RPart 中找到生存预测函数,但是正如 Achim 上面指出的那样,partykit 使用 Kaplan-Meier 估计,即来自最终叶子的非参数生存。我认为在生存随机森林树中也是一样的,在最后的叶子中使用 K-M 曲线。
这个问题中的模拟数据使用指数分布,因此 K-M 和指数生存曲线在设计上是相似的,但是对于不同的模拟或现实生活分布,通过 RPart 树估计指数率并在最终中使用 K-M 曲线(同一棵树的)叶子会给出不同的存活率。
sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)
# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
table(dat$node)
s0 = survreg(Surv(t,e)~ 1, data = dat, dist = "exponential") #-0.6175
e0 = exp(-summary(s0)$coefficients[1]); e0 #1.854
rates = unique(predict(tfit))
#1) plot K-M curves by node (black):
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )
#2) plot exponential survival with rates = e0 * RPart rates (red):
for (rate in rates) {
grid= seq(0,1,length.out = 100)
lines(x= grid, y= exp(-e0*rate*(grid)), col=2)
}
#3) plot partykit survival curves based on RPart tree (green)
library(partykit)
tfit2 <- as.party(tfit)
col_n = 1
for (node in names(table(dat$node))){
predict_curve = predict(tfit2, newdata = dat[dat$node == node, ], type = "prob")
surv_esitmated = approxfun(predict_curve[[1]]$time, predict_curve[[1]]$surv)
lines(x= grid, y= surv_esitmated(grid), col = 2+col_n)
col_n=+1
}