向树添加信息 - Rpart

Adding informations to tree - Rpart

我想为我的树添加一些信息。例如,假设我有一个这样的数据库:

library(rpart)
library(rpart.plot)
set.seed(1)
mydb<-data.frame(results=rnorm(1000,0,1),expo=runif(1000),var1=sample(LETTERS[1:4],1000,replace=T),
                 var2=sample(LETTERS[5:6],1000,replace=T),var3=sample(LETTERS[20:25], 1000,replace=T))

我可以运行一棵树:

mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0)
pfit<- prune(mytree, cp=mytree$cptable[4,"CP"])
prp(pfit,type=1,extra=100,fallen.leaves=F,shadow.col="darkgray",box.col=rgb(0.8,0.9,0.8))

结果如下所示:

这对我来说没问题,但假设我想知道每片叶子的平均曝光量。

我知道我可以在 prp 中添加一些信息,例如每片叶子的重量和函数:

node.fun1 <- function(x, labs, digits, varlen)
{
  paste("Weight \n",x$frame$wt)
}

prp(pfit,type=1,extra=100,fallen.leaves=F,shadow.col="darkgray",box.col=rgb(0.8,0.9,0.8),node.fun = node.fun1)

但只有在框架中计算时才有效,即 rpart 函数的结果。

我的问题:

如何将自定义信息添加到图中,例如平均曝光率,或任何其他计算自定义指标并将其添加到 table frame 的函数?

这真的很好,我不知道这是一个选项。

所有的工作似乎都是在获取每个节点上使用的原始数据的子集。这对于终端节点来说很容易,但我没有找到一种直接的方法来识别每个节点中使用的数据行,而不仅仅是叶子。如果有人知道更简单的方法,我很想听听。

library('rpart.plot')
set.seed(1)
mydb<-data.frame(results=rnorm(1000,0,1),expo=runif(1000),var1=sample(LETTERS[1:4],1000,replace=T),
                 var2=sample(LETTERS[5:6],1000,replace=T),var3=sample(LETTERS[20:25], 1000,replace=T))
mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0)
pfit<- prune(mytree, cp=mytree$cptable[4,"CP"])

rpart.plot(pfit)

定义接受 x 的新函数,拟合 rpart 的结果(我没有研究其他参数,但小插图应该会有所帮助)。

对于 x$frame 的每一行,我们需要获取用于计算汇总统计数据的数据。不幸的是,x$where 只告诉我们每个观测值所在的终端节点。所以对于每个节点号,我们使用subset.rpart获取底层数据,想怎么用就怎么用

f <- function(x, labs, digits, varlen) {
  nodes <- as.integer(rownames(x$frame))
  z <- sapply(nodes, function(y) {
    data <- subset.rpart(x, y)
    c(mean = mean(data$expo), nrow(data), nrow(data) / length(x$where) * 100)
  })
  sprintf('Mean expo: %.2f\nn=%.0f (%.0f%%)', z[1, ], z[2, ], z[3, ])
}

prp(pfit, type=1, extra=100, fallen.leaves=FALSE,
    shadow.col="darkgray", box.col=rgb(0.8,0.9,0.8),
    node.fun = f)

这项工作是由 subset.rpart 完成的,它采用节点编号和 returns 节点上使用的 data 的子集。

subset.rpart <- function(tree, node = 1L) {
  ## returns subset of tree$call$data used on any node
  data <- eval(tree$call$data, parent.frame(1L))
  wh <- sapply(as.integer(rownames(tree$frame)), parent)
  wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
  data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]
}

parent <- function(x) {
  ## returns vector of parent nodes
  if (x[1] != 1)
    c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
}

测试

## tests
dim(subset.rpart(pfit, 1)) == dim(mydb)
# [1] TRUE TRUE

## terminal nodes
nodes <- as.integer(rownames(pfit$frame[pfit$frame$var %in% '<leaf>', ]))
sum(sapply(nodes, function(x) nrow(subset.rpart(pfit, x)))) == nrow(mydb)
# [1] TRUE

我不知道这是否正是您想要的,但请尝试 "sparkline" 和 "visNetwork" 包。他们使用 rpart 对象