rpart 创建一个 table 指示观察是否属于节点
rpart create a table that indicates if an observation belongs to a node or not
以下是我想做的:
- 用 rpart 为某个数据集拟合(生长)一棵树
- 创建一个 table,原始数据集中每个观察一行,树中每个节点一列,外加一个 ID。如果观察属于该节点,则节点列应取值 1,否则为零。
这是我写的一些代码:
library(rpart)
library(rattle)
# OP's attempt: add one indicator column per tree node to `data`.
data <- kyphosis
fit <- rpart(Age ~ Number + Start, data = kyphosis)
fancyRpartPlot(fit)

# Node numbers are taken from the row names of fit$frame.
nodeNumbers <- as.numeric(rownames(fit$frame))
paths <- path.rpart(fit, nodeNumbers)

# seq_along() instead of 1:length() so an empty vector yields zero iterations.
for (i in seq_along(nodeNumbers)) {
  nodeNumber <- nodeNumbers[i]
  nodeColumn <- paste0('gp', nodeNumber)
  data[, nodeColumn] <- NA
  path <- paths[[i]]
  if (length(path) == 1) {
    # i.e. we're at the root
    data[, nodeColumn] <- 1
  } else {
    # Placeholder: membership for non-root nodes is not worked out yet.
    print('help')
  }
}
data
是否有一个包可以满足我的需求?我能想到的唯一方法是对 paths 对象使用一些正则表达式技巧。我猜测(也希望)有一种更简单的方法可以做到这一点。
Is there a package out there to do what I need?
据我所知(AFAIK)没有,但下面的代码在 rpart 4.1.13 版本中有效:
# Build the node-membership matrix the OP asks for.
#
# Arguments:
#   object  a fitted rpart tree; only object$frame is read.
#   where   leaf index for each observation (row positions in object$frame),
#           e.g. fit$where.
#
# Returns a logical matrix with one row per element of `where` and one column
# per tree node, named "GP<node number>".
get_nodes <- function(object, where) {
  node_ids <- row.names(object$frame)
  # descendants() is not defined in this file; it is resolved through the
  # environment assigned below. NOTE(review): assumed to return a node-by-node
  # logical matrix (ancestors in rows, descendants in columns) -- confirm
  # against the installed rpart version.
  node_by_ancestor <- t(descendants(as.numeric(node_ids)))
  membership <- node_by_ancestor[where, , drop = FALSE]
  colnames(membership) <- paste0("GP", node_ids)
  membership
}
# Run get_nodes inside rpart's environment so that descendants() can be found.
environment(get_nodes) <- environment(rpart)
# use function
# fit$where is passed as the leaf index of each observation used to fit the
# tree (presumably row positions in fit$frame -- confirm in ?rpart.object).
nodes <- get_nodes(fit, fit$where)
# One row per observation, one logical column per node; show the first nine.
head(nodes, 9)
#R GP1 GP2 GP3 GP6 GP7 GP14 GP15
#R [1,] TRUE FALSE TRUE FALSE TRUE TRUE FALSE
#R [2,] TRUE FALSE TRUE FALSE TRUE FALSE TRUE
#R [3,] TRUE FALSE TRUE FALSE TRUE TRUE FALSE
#R [4,] TRUE TRUE FALSE FALSE FALSE FALSE FALSE
#R [5,] TRUE FALSE TRUE FALSE TRUE FALSE TRUE
#R [6,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [7,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [8,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [9,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
# compare with
# The same nine observations in the original data, for cross-checking the
# node assignments against the tree's split variables.
head(data, 9)
#R Kyphosis Age Number Start
#R 1 absent 71 3 5
#R 2 absent 158 3 14
#R 3 present 128 4 5
#R 4 absent 2 5 1
#R 5 absent 1 4 15
#R 6 absent 1 2 16
#R 7 absent 61 2 17
#R 8 absent 37 3 16
#R 9 absent 113 2 16
下面是完整代码:拟合模型,创建一个可以为新数据集获取终端叶节点的函数,然后创建并使用上述函数。
# do as OP
library(rpart)
library(rattle)
# Keep a working copy named `data`; the tree itself is fitted on kyphosis.
data <- kyphosis
# Fit a tree with Age as the response and Number and Start as predictors.
fit <- rpart(Age ~ Number + Start, data = kyphosis)
# Plot the fitted tree (fancyRpartPlot is not defined in this file;
# presumably provided by rattle).
fancyRpartPlot(fit)
# Find the leaf that each row of `newdata` falls into.
#
# Arguments:
#   object     a fitted rpart tree.
#   newdata    data to send down the tree. When it carries no "terms"
#              attribute it is first converted to a model frame built from
#              the tree's terms with the response removed.
#   na.action  forwarded to model.frame(); defaults to na.pass.
#
# Returns the result of pred.rpart(); the surrounding script uses it as the
# leaf index, i.e. as row positions in fit$frame.
get_where <- function(object, newdata, na.action = na.pass) {
  needs_model_frame <- is.null(attr(newdata, "terms"))
  if (needs_model_frame) {
    predictor_terms <- delete.response(object$terms)
    newdata <- model.frame(
      predictor_terms, newdata,
      na.action = na.action,
      xlev = attr(object, "xlevels")
    )
    # Only validate column classes when the terms object recorded them.
    data_classes <- attr(predictor_terms, "dataClasses")
    if (!is.null(data_classes)) {
      .checkMFClasses(data_classes, newdata, TRUE)
    }
  }
  pred.rpart(object, rpart.matrix(newdata))
}
# pred.rpart() and rpart.matrix() are not defined in this file; evaluate
# get_where inside rpart's environment so they can be found.
environment(get_where) <- environment(rpart)
# check that we get the correct value
# Leaf indices computed for the training data must reproduce predict():
# fit$frame$yval[where] looks up the value stored for each observation's leaf.
where <- get_where(fit, data)
stopifnot(isTRUE(all.equal(
fit$frame$yval[where], unname(predict(fit, newdata = data)))))
# Build the node-membership matrix the OP asks for.
#
# Arguments:
#   object  a fitted rpart tree; only object$frame is read.
#   where   leaf index for each observation (row positions in object$frame),
#           e.g. fit$where.
#
# Returns a logical matrix with one row per element of `where` and one column
# per tree node, named "GP<node number>".
get_nodes <- function(object, where) {
  node_ids <- row.names(object$frame)
  # descendants() is not defined in this file; it is resolved through the
  # environment assigned below. NOTE(review): assumed to return a node-by-node
  # logical matrix (ancestors in rows, descendants in columns) -- confirm
  # against the installed rpart version.
  node_by_ancestor <- t(descendants(as.numeric(node_ids)))
  membership <- node_by_ancestor[where, , drop = FALSE]
  colnames(membership) <- paste0("GP", node_ids)
  membership
}
# Run get_nodes inside rpart's environment so that descendants() can be found.
environment(get_nodes) <- environment(rpart)
# use function
# `where` is the leaf index computed by get_where() above.
nodes <- get_nodes(fit, where)
# One row per observation, one logical column per node; show the first nine.
head(nodes, 9)
#R GP1 GP2 GP3 GP6 GP7 GP14 GP15
#R [1,] TRUE FALSE TRUE FALSE TRUE TRUE FALSE
#R [2,] TRUE FALSE TRUE FALSE TRUE FALSE TRUE
#R [3,] TRUE FALSE TRUE FALSE TRUE TRUE FALSE
#R [4,] TRUE TRUE FALSE FALSE FALSE FALSE FALSE
#R [5,] TRUE FALSE TRUE FALSE TRUE FALSE TRUE
#R [6,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [7,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [8,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [9,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
# compare with
# The same nine observations in the original data, for cross-checking the
# node assignments against the tree's split variables.
head(data, 9)
#R Kyphosis Age Number Start
#R 1 absent 71 3 5
#R 2 absent 158 3 14
#R 3 present 128 4 5
#R 4 absent 2 5 1
#R 5 absent 1 4 15
#R 6 absent 1 2 16
#R 7 absent 61 2 17
#R 8 absent 37 3 16
#R 9 absent 113 2 16
代码来自 rpart:::predict.rpart 和 rpart::path.rpart。当然,您可以根据需要合并 get_where 和 get_nodes 函数。
以下是我想做的:
- 用 rpart 为某个数据集拟合(生长)一棵树
- 创建一个 table,原始数据集中每个观察一行,树中每个节点一列,外加一个 ID。如果观察属于该节点,则节点列应取值 1,否则为零。
这是我写的一些代码:
library(rpart)
library(rattle)
# OP's attempt: add one indicator column per tree node to `data`.
data <- kyphosis
fit <- rpart(Age ~ Number + Start, data = kyphosis)
fancyRpartPlot(fit)

# Node numbers are taken from the row names of fit$frame.
nodeNumbers <- as.numeric(rownames(fit$frame))
paths <- path.rpart(fit, nodeNumbers)

# seq_along() instead of 1:length() so an empty vector yields zero iterations.
for (i in seq_along(nodeNumbers)) {
  nodeNumber <- nodeNumbers[i]
  nodeColumn <- paste0('gp', nodeNumber)
  data[, nodeColumn] <- NA
  path <- paths[[i]]
  if (length(path) == 1) {
    # i.e. we're at the root
    data[, nodeColumn] <- 1
  } else {
    # Placeholder: membership for non-root nodes is not worked out yet.
    print('help')
  }
}
data
是否有一个包可以满足我的需求?我能想到的唯一方法是对 paths 对象使用一些正则表达式技巧。我猜测(也希望)有一种更简单的方法可以做到这一点。
Is there a package out there to do what I need?
据我所知(AFAIK)没有,但下面的代码在 rpart 4.1.13 版本中有效:
# Build the node-membership matrix the OP asks for.
#
# Arguments:
#   object  a fitted rpart tree; only object$frame is read.
#   where   leaf index for each observation (row positions in object$frame),
#           e.g. fit$where.
#
# Returns a logical matrix with one row per element of `where` and one column
# per tree node, named "GP<node number>".
get_nodes <- function(object, where) {
  node_ids <- row.names(object$frame)
  # descendants() is not defined in this file; it is resolved through the
  # environment assigned below. NOTE(review): assumed to return a node-by-node
  # logical matrix (ancestors in rows, descendants in columns) -- confirm
  # against the installed rpart version.
  node_by_ancestor <- t(descendants(as.numeric(node_ids)))
  membership <- node_by_ancestor[where, , drop = FALSE]
  colnames(membership) <- paste0("GP", node_ids)
  membership
}
# Run get_nodes inside rpart's environment so that descendants() can be found.
environment(get_nodes) <- environment(rpart)
# use function
# fit$where is passed as the leaf index of each observation used to fit the
# tree (presumably row positions in fit$frame -- confirm in ?rpart.object).
nodes <- get_nodes(fit, fit$where)
# One row per observation, one logical column per node; show the first nine.
head(nodes, 9)
#R GP1 GP2 GP3 GP6 GP7 GP14 GP15
#R [1,] TRUE FALSE TRUE FALSE TRUE TRUE FALSE
#R [2,] TRUE FALSE TRUE FALSE TRUE FALSE TRUE
#R [3,] TRUE FALSE TRUE FALSE TRUE TRUE FALSE
#R [4,] TRUE TRUE FALSE FALSE FALSE FALSE FALSE
#R [5,] TRUE FALSE TRUE FALSE TRUE FALSE TRUE
#R [6,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [7,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [8,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [9,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
# compare with
# The same nine observations in the original data, for cross-checking the
# node assignments against the tree's split variables.
head(data, 9)
#R Kyphosis Age Number Start
#R 1 absent 71 3 5
#R 2 absent 158 3 14
#R 3 present 128 4 5
#R 4 absent 2 5 1
#R 5 absent 1 4 15
#R 6 absent 1 2 16
#R 7 absent 61 2 17
#R 8 absent 37 3 16
#R 9 absent 113 2 16
下面是完整代码:拟合模型,创建一个可以为新数据集获取终端叶节点的函数,然后创建并使用上述函数。
# do as OP
library(rpart)
library(rattle)
# Keep a working copy named `data`; the tree itself is fitted on kyphosis.
data <- kyphosis
# Fit a tree with Age as the response and Number and Start as predictors.
fit <- rpart(Age ~ Number + Start, data = kyphosis)
# Plot the fitted tree (fancyRpartPlot is not defined in this file;
# presumably provided by rattle).
fancyRpartPlot(fit)
# Find the leaf that each row of `newdata` falls into.
#
# Arguments:
#   object     a fitted rpart tree.
#   newdata    data to send down the tree. When it carries no "terms"
#              attribute it is first converted to a model frame built from
#              the tree's terms with the response removed.
#   na.action  forwarded to model.frame(); defaults to na.pass.
#
# Returns the result of pred.rpart(); the surrounding script uses it as the
# leaf index, i.e. as row positions in fit$frame.
get_where <- function(object, newdata, na.action = na.pass) {
  needs_model_frame <- is.null(attr(newdata, "terms"))
  if (needs_model_frame) {
    predictor_terms <- delete.response(object$terms)
    newdata <- model.frame(
      predictor_terms, newdata,
      na.action = na.action,
      xlev = attr(object, "xlevels")
    )
    # Only validate column classes when the terms object recorded them.
    data_classes <- attr(predictor_terms, "dataClasses")
    if (!is.null(data_classes)) {
      .checkMFClasses(data_classes, newdata, TRUE)
    }
  }
  pred.rpart(object, rpart.matrix(newdata))
}
# pred.rpart() and rpart.matrix() are not defined in this file; evaluate
# get_where inside rpart's environment so they can be found.
environment(get_where) <- environment(rpart)
# check that we get the correct value
# Leaf indices computed for the training data must reproduce predict():
# fit$frame$yval[where] looks up the value stored for each observation's leaf.
where <- get_where(fit, data)
stopifnot(isTRUE(all.equal(
fit$frame$yval[where], unname(predict(fit, newdata = data)))))
# Build the node-membership matrix the OP asks for.
#
# Arguments:
#   object  a fitted rpart tree; only object$frame is read.
#   where   leaf index for each observation (row positions in object$frame),
#           e.g. fit$where.
#
# Returns a logical matrix with one row per element of `where` and one column
# per tree node, named "GP<node number>".
get_nodes <- function(object, where) {
  node_ids <- row.names(object$frame)
  # descendants() is not defined in this file; it is resolved through the
  # environment assigned below. NOTE(review): assumed to return a node-by-node
  # logical matrix (ancestors in rows, descendants in columns) -- confirm
  # against the installed rpart version.
  node_by_ancestor <- t(descendants(as.numeric(node_ids)))
  membership <- node_by_ancestor[where, , drop = FALSE]
  colnames(membership) <- paste0("GP", node_ids)
  membership
}
# Run get_nodes inside rpart's environment so that descendants() can be found.
environment(get_nodes) <- environment(rpart)
# use function
# `where` is the leaf index computed by get_where() above.
nodes <- get_nodes(fit, where)
# One row per observation, one logical column per node; show the first nine.
head(nodes, 9)
#R GP1 GP2 GP3 GP6 GP7 GP14 GP15
#R [1,] TRUE FALSE TRUE FALSE TRUE TRUE FALSE
#R [2,] TRUE FALSE TRUE FALSE TRUE FALSE TRUE
#R [3,] TRUE FALSE TRUE FALSE TRUE TRUE FALSE
#R [4,] TRUE TRUE FALSE FALSE FALSE FALSE FALSE
#R [5,] TRUE FALSE TRUE FALSE TRUE FALSE TRUE
#R [6,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [7,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [8,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
#R [9,] TRUE FALSE TRUE TRUE FALSE FALSE FALSE
# compare with
# The same nine observations in the original data, for cross-checking the
# node assignments against the tree's split variables.
head(data, 9)
#R Kyphosis Age Number Start
#R 1 absent 71 3 5
#R 2 absent 158 3 14
#R 3 present 128 4 5
#R 4 absent 2 5 1
#R 5 absent 1 4 15
#R 6 absent 1 2 16
#R 7 absent 61 2 17
#R 8 absent 37 3 16
#R 9 absent 113 2 16
代码来自 rpart:::predict.rpart 和 rpart::path.rpart。当然,您可以根据需要合并 get_where 和 get_nodes 函数。