在 rpart 的节点中获取观察结果(即:CART)
Getting the observations in a rpart's node (i.e.: CART)
我想检查到达 rpart 决策树中某个节点的所有观察结果。例如,在下面的代码中:
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit
n= 81
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 81 17 absent (0.79012346 0.20987654)
2) Start>=8.5 62 6 absent (0.90322581 0.09677419)
4) Start>=14.5 29 0 absent (1.00000000 0.00000000) *
5) Start< 14.5 33 6 absent (0.81818182 0.18181818)
10) Age< 55 12 0 absent (1.00000000 0.00000000) *
11) Age>=55 21 6 absent (0.71428571 0.28571429)
22) Age>=111 14 2 absent (0.85714286 0.14285714) *
23) Age< 111 7 3 present (0.42857143 0.57142857) *
3) Start< 8.5 19 8 present (0.42105263 0.57894737) *
我想查看节点 (5) 中的所有观察结果(即:Start>=8.5 和 Start<14.5 的 33 个观察结果)。显然我可以手动找到它们。但我想要一些功能,比如(比如)"get_node_date"。为此,我可以 运行 get_node_date(5) - 并获得相关的观察结果。
关于如何解决这个问题有什么建议吗?
rpart returns rpart.object 包含您需要的信息的元素:
require(rpart)
fit2 <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit2
get_node_date <-function(nodeId,fit)
{
fit$frame[toString(nodeId),"n"]
}
for (i in c(1,2,4,5,10,11,22,23,3) )
cat(get_node_date(i,fit2),"\n")
似乎没有这样的功能可以从特定节点提取观察结果。我会按如下方式解决:首先确定哪个rule/s is/are用于您感兴趣的节点。您可以使用path.rpart
。然后,您可以一个接一个地应用 rule/s 来提取观察结果。
这种方法作为一个函数:
get_node_date <- function(tree = fit, node = 5){
rule <- path.rpart(tree, node)
rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE))
ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], kyphosis[,x[1]], as.numeric(x[3]))))), 1, all)
kyphosis[ind,]
}
对于节点 5,您得到:
get_node_date()
node number: 5
root
Start>=8.5
Start< 14.5
Kyphosis Age Number Start
2 absent 158 3 14
10 present 59 6 12
11 present 82 5 14
14 absent 1 4 12
18 absent 175 5 13
20 absent 27 4 9
23 present 96 3 12
26 absent 9 5 13
28 absent 100 3 14
32 absent 125 2 11
33 absent 130 5 13
35 absent 140 5 11
37 absent 1 3 9
39 absent 20 6 9
40 present 91 5 12
42 absent 35 3 13
46 present 139 3 10
48 absent 131 5 13
50 absent 177 2 14
51 absent 68 5 10
57 absent 2 3 13
59 absent 51 7 9
60 absent 102 3 13
66 absent 17 4 10
68 absent 159 4 13
69 absent 18 4 11
71 absent 158 5 14
72 absent 127 4 12
74 absent 206 4 10
77 present 157 3 13
78 absent 26 7 13
79 absent 120 2 13
81 absent 36 4 13
partykit
软件包也为此提供了固定解决方案。您只需要将 rpart
对象转换为 party
class 对象,以便使用其统一接口来处理树。然后就可以使用data_party()
函数了。
使用问题中的 fit
并加载 library("partykit")
您可以首先将 rpart
树强制为 party
:
pfit <- as.party(fit)
plot(pfit)
以您想要的方式提取数据只有两个小麻烦:(1) 原始拟合中的 model.frame()
总是在强制转换中丢失,需要手动重新附加。 (2) 节点采用不同的编号方案。您现在需要节点 4(而不是 5)。
pfit$data <- model.frame(fit)
data4 <- data_party(pfit, 4)
dim(data4)
## [1] 33 5
head(data4)
## Kyphosis Age Start (fitted) (response)
## 2 absent 158 14 7 absent
## 10 present 59 12 8 present
## 11 present 82 14 8 present
## 14 absent 1 12 5 absent
## 18 absent 175 13 7 absent
## 20 absent 27 9 5 absent
另一种方法是从节点 4 开始对子树进行子集化,然后从中获取数据:
pfit4 <- pfit[4]
plot(pfit4)
然后data_party(pfit4)
给你和上面data4
一样的结果。 pfit4$data
为您提供没有 (fitted)
节点和预测的 (response)
.
的数据
还有另一种方法,它通过查找任何特定节点的所有终端节点并返回调用中使用的数据子集来实现。
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
head(subset.rpart(fit, 5))
# Kyphosis Age Number Start
# 2 absent 158 3 14
# 10 present 59 6 12
# 11 present 82 5 14
# 14 absent 1 4 12
# 18 absent 175 5 13
# 20 absent 27 4 9
subset.rpart <- function(tree, node = 1L) {
data <- eval(tree$call$data, parent.frame(1L))
wh <- sapply(as.integer(rownames(tree$frame)), parent)
wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]
}
parent <- function(x) {
if (x[1] != 1)
c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
}
比原版 post 晚了两年,但可能对其他人有用。 rpart 中训练观察的节点分配可以从 $where
获得:
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit$where
作为函数:
get_node <- function(rpart.object=fit, data=kyphosis, node.number=5) {
data[which(fit$where == node.number),]
}
get_node()
这仅适用于训练观察,不适用于新观察。
我想检查到达 rpart 决策树中某个节点的所有观察结果。例如,在下面的代码中:
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit
n= 81
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 81 17 absent (0.79012346 0.20987654)
2) Start>=8.5 62 6 absent (0.90322581 0.09677419)
4) Start>=14.5 29 0 absent (1.00000000 0.00000000) *
5) Start< 14.5 33 6 absent (0.81818182 0.18181818)
10) Age< 55 12 0 absent (1.00000000 0.00000000) *
11) Age>=55 21 6 absent (0.71428571 0.28571429)
22) Age>=111 14 2 absent (0.85714286 0.14285714) *
23) Age< 111 7 3 present (0.42857143 0.57142857) *
3) Start< 8.5 19 8 present (0.42105263 0.57894737) *
我想查看节点 (5) 中的所有观察结果(即:Start>=8.5 和 Start<14.5 的 33 个观察结果)。显然我可以手动找到它们。但我想要一些功能,比如(比如)"get_node_date"。为此,我可以 运行 get_node_date(5) - 并获得相关的观察结果。
关于如何解决这个问题有什么建议吗?
rpart returns rpart.object 包含您需要的信息的元素:
require(rpart)
fit2 <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit2
get_node_date <-function(nodeId,fit)
{
fit$frame[toString(nodeId),"n"]
}
for (i in c(1,2,4,5,10,11,22,23,3) )
cat(get_node_date(i,fit2),"\n")
似乎没有这样的功能可以从特定节点提取观察结果。我会按如下方式解决:首先确定哪个rule/s is/are用于您感兴趣的节点。您可以使用path.rpart
。然后,您可以一个接一个地应用 rule/s 来提取观察结果。
这种方法作为一个函数:
get_node_date <- function(tree = fit, node = 5){
rule <- path.rpart(tree, node)
rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE))
ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], kyphosis[,x[1]], as.numeric(x[3]))))), 1, all)
kyphosis[ind,]
}
对于节点 5,您得到:
get_node_date()
node number: 5
root
Start>=8.5
Start< 14.5
Kyphosis Age Number Start
2 absent 158 3 14
10 present 59 6 12
11 present 82 5 14
14 absent 1 4 12
18 absent 175 5 13
20 absent 27 4 9
23 present 96 3 12
26 absent 9 5 13
28 absent 100 3 14
32 absent 125 2 11
33 absent 130 5 13
35 absent 140 5 11
37 absent 1 3 9
39 absent 20 6 9
40 present 91 5 12
42 absent 35 3 13
46 present 139 3 10
48 absent 131 5 13
50 absent 177 2 14
51 absent 68 5 10
57 absent 2 3 13
59 absent 51 7 9
60 absent 102 3 13
66 absent 17 4 10
68 absent 159 4 13
69 absent 18 4 11
71 absent 158 5 14
72 absent 127 4 12
74 absent 206 4 10
77 present 157 3 13
78 absent 26 7 13
79 absent 120 2 13
81 absent 36 4 13
partykit
软件包也为此提供了固定解决方案。您只需要将 rpart
对象转换为 party
class 对象,以便使用其统一接口来处理树。然后就可以使用data_party()
函数了。
使用问题中的 fit
并加载 library("partykit")
您可以首先将 rpart
树强制为 party
:
pfit <- as.party(fit)
plot(pfit)
以您想要的方式提取数据只有两个小麻烦:(1) 原始拟合中的 model.frame()
总是在强制转换中丢失,需要手动重新附加。 (2) 节点采用不同的编号方案。您现在需要节点 4(而不是 5)。
pfit$data <- model.frame(fit)
data4 <- data_party(pfit, 4)
dim(data4)
## [1] 33 5
head(data4)
## Kyphosis Age Start (fitted) (response)
## 2 absent 158 14 7 absent
## 10 present 59 12 8 present
## 11 present 82 14 8 present
## 14 absent 1 12 5 absent
## 18 absent 175 13 7 absent
## 20 absent 27 9 5 absent
另一种方法是从节点 4 开始对子树进行子集化,然后从中获取数据:
pfit4 <- pfit[4]
plot(pfit4)
然后data_party(pfit4)
给你和上面data4
一样的结果。 pfit4$data
为您提供没有 (fitted)
节点和预测的 (response)
.
还有另一种方法,它通过查找任何特定节点的所有终端节点并返回调用中使用的数据子集来实现。
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
head(subset.rpart(fit, 5))
# Kyphosis Age Number Start
# 2 absent 158 3 14
# 10 present 59 6 12
# 11 present 82 5 14
# 14 absent 1 4 12
# 18 absent 175 5 13
# 20 absent 27 4 9
subset.rpart <- function(tree, node = 1L) {
data <- eval(tree$call$data, parent.frame(1L))
wh <- sapply(as.integer(rownames(tree$frame)), parent)
wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]
}
parent <- function(x) {
if (x[1] != 1)
c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
}
比原版 post 晚了两年,但可能对其他人有用。 rpart 中训练观察的节点分配可以从 $where
获得:
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit$where
作为函数:
get_node <- function(rpart.object=fit, data=kyphosis, node.number=5) {
data[which(fit$where == node.number),]
}
get_node()
这仅适用于训练观察,不适用于新观察。