R从partykit决策树中提取终端节点信息

R extract terminal node info from partykit decision tree

我已经创建了一个constparty 决策树(自定义拆分规则)并打印出树结果。结果如下所示:

Fitted party:
[1] root
|   [2] value.a < 1651: 0.067 (n = 1419, err = 88.6)
|   [3] value.a >= 1651: 0.571 (n = 7, err = 1.7)

我正在尝试提取终端节点信息 (yval:0.067 和 0.571; 每个节点上的 n:1419 和 7; 和错误:88.6 和 1.7) 并将它们放入列表中,同时具有相应的节点 ID(节点 ID 2 和 3),以便我以后可以利用这些信息。

我研究 partykit 函数有一段时间了,但找不到可以帮助我提取刚刚列出的那些信息的函数。

有人可以帮我吗?谢谢!

像往常一样,有几种方法可以获得您正在寻找的信息。提取存储在特定 node 中的 info 的技术方法是使用 nodeapply(object, ids, info_node) 其中 info_node returns 存储在相应节点中的信息列表。

然而,在 constparty 个对象的终端节点中没有存储任何内容。相反,拟合节点的响应的整个分布被存储并且可以通过 fitted(object) 提取。这包含一个数据框,其中包含观察到的 (response)(fitted) 节点和观察到的 (weights)(如果有)。然后你可以很容易地使用 tapply()aggregate() 或类似的东西来计算节点方式等

或者,您可以将 constparty 对象转换为 simpleparty 对象,该对象将打印信息存储在节点中并提取它。

这两种策略的一个有效示例是 cars 数据的简单回归树:

library("partykit")
data("cars", package = "datasets")
ct <- ctree(dist ~ speed, data = cars)

然后您可以通过

轻松计算节点 means
with(fitted(ct), tapply(`(response)`, `(fitted)`, mean))
##        3        4        5 
## 18.20000 39.75000 65.26316 

当然,您可以将 mean 替换为您感兴趣的任何其他摘要统计信息。

simplepartynodeapply() 可以通过以下方式获得:

nodeapply(as.simpleparty(ct), ids = nodeids(ct, terminal = TRUE), info_node)
## $`3`
## $`3`$prediction
## [1] 18.2
## 
## $`3`$n
##  n 
## 15 
## 
## $`3`$error
## [1] 1176.4
## 
## $`3`$distribution
## NULL
## 
## $`3`$p.value
## NULL
## 
## 
## $`4`
## $`4`$prediction
## [1] 39.75
## ...