Calculate number of observations in each node in a decision tree in R?

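Similar questions have been asked before, for example here and here, but none of them can be applied to my problem. I am trying to identify and count the observations in each node of a decision tree. However, the tree structure comes from a data frame of trees that I built from the BART package: I extract the tree information from BART and convert it into a data frame like the one shown below (i.e. df), and I need to work with that data frame structure. Aside: I believe the way the trees are drawn/ordered in my data frame is called 'depth first'.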

For example, my data frame of trees looks like this:

library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))

Visually, the trees look like this:

Reading down df, the trees are laid out left-first. In addition, all splits are binary, so every internal node has exactly 2 children.
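To make the depth-first layout concrete, here is a minimal sketch of how the child positions could be recovered from the row order alone. The helper name child_rows is made up for illustration, and the sketch assumes a well-formed tree in which every non-NA row has exactly two children:

```r
# Hypothetical helper: given one tree's variableName column in
# depth-first order, return the rows of each node's two children.
child_rows <- function(variableName) {
  n     <- length(variableName)
  size  <- integer(n)              # rows used by the subtree rooted at row i
  left  <- rep(NA_integer_, n)
  right <- rep(NA_integer_, n)
  for (i in rev(seq_len(n))) {     # work upwards so children are known first
    if (is.na(variableName[i])) {
      size[i] <- 1L                # a terminal node occupies a single row
    } else {
      left[i]  <- i + 1L                    # left child follows its parent
      right[i] <- left[i] + size[left[i]]   # right child follows the left subtree
      size[i]  <- 1L + size[left[i]] + size[right[i]]
    }
  }
  data.frame(row = seq_len(n), left = left, right = right, size = size)
}

# Tree 3 from df
child_rows(c("x5", "x4", NA, NA, "x3", NA, NA))
```

For tree 3 the root's children come out as rows 2 and 5, and the children of row 5 as rows 6 and 7.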

Now suppose we create some data like this:

set.seed(100)
dat <- data.frame( x1 = runif(10),
                   x2 = runif(10),
                   x3 = runif(10),
                   x4 = runif(10),
                   x5 = runif(10)
)

I am trying to find out which observations of dat fall into which node.

Attempted answer: this isn't very helpful, but for clarity (I'm still working on the problem), hard-coding it for tree number three looks like this:

lists <- df %>% group_by(treeNo) %>% group_split()
tree <- lists[[3]]

# Root split (row 1 of tree 3). The column is looked up by its exact
# name, since grepl() would also match e.g. "x10" when asked for "x1".
namesDf <- tree$variableName[1]
dataLeft  <- dat[dat[[namesDf]] <= tree$splitValue[1], ]
dataRight <- dat[dat[[namesDf]] >  tree$splitValue[1], ]

# Left child of the root (row 2), applied to dataLeft
namesDf <- tree$variableName[2]
dataLeft1  <- dataLeft[dataLeft[[namesDf]] <= tree$splitValue[2], ]
dataRight1 <- dataLeft[dataLeft[[namesDf]] >  tree$splitValue[2], ]

# Right child of the root (row 5), applied to dataRight
namesDf <- tree$variableName[5]
dataLeft2  <- dataRight[dataRight[[namesDf]] <= tree$splitValue[5], ]
dataRight2 <- dataRight[dataRight[[namesDf]] >  tree$splitValue[5], ]

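I've been trying to turn this into a loop, but it is proving challenging to work out, and I (obviously) can't hard-code this for every tree. Any suggestions on how to solve this?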

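There is still a lot of room for optimization, but here is my attempt. Your trees appear to be built depth-first, with the left child always immediately following its parent node: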

library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))

Given the data to match:

set.seed(100)
dat <- data.frame( x1 = runif(10),
                   x2 = runif(10),
                   x3 = runif(10),
                   x4 = runif(10),
                   x5 = runif(10)
)
dat
##>           x1        x2        x3        x4        x5
##>1  0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
##>2  0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
##>3  0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
##>4  0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
##>5  0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
##>6  0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
##>7  0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
##>8  0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
##>9  0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
##>10 0.17026205 0.6902905 0.2777238 0.1302889 0.3070859

makeTree is a higher-order function: it returns a function that in turn maps a row of values to a node:

makeTree <- function(dat, r = 1) {
  ## dat is the data frame representation of a single tree,
  ## as in the example; r is the row of the node to process.
  ## Returns a list of two elements, size and fn:
  ## - size is the number of rows taken up by the node
  ##   and its descendants.
  ## - fn is a function of one argument (either a list or
  ##   a row of a data frame) that returns the index of the
  ##   terminal node matching that argument, more precisely
  ##   the value of the column id in dat.
  stopifnot(r <= nrow(dat))
  vname <- pull(dat,variableName)[r]
  splitVal <- pull(dat, splitValue)[r]
  if (is.na(vname)) {
    ## terminal node
    ## print(sprintf("terminal node: %i", r))
    res <- list(size = 1, # offset to access right node
                fn = function(z) {
                  pull(dat, "id")[r]
                })
    return(res)
  } else {
    ##print(sprintf("node: %i, varName: %s, splitVal: %f", r, vname, splitVal ))
    ## compute the left and right functions
    ## note that the tree is traversed depth-first 
    fnleft <- makeTree(dat, r + 1) # the left child is always positioned
                                   # right after the caller
    fnright <- makeTree(dat, r + fnleft$size + 1 )
    return(list(size = fnleft$size + fnright$size + 1,
                fn = function(z) {
                  if (z[vname] <= splitVal)
                    fnleft$fn(z)
                  else
                    fnright$fn(z)
                }))
  }
}

Now makeTree is applied to each tree to produce a list of matching functions:

treefns <- df |>
  mutate(id = row_number()) %>%
  group_by(treeNo) |>
  group_split()    |>
  purrr::map(makeTree) |>
  purrr::map("fn")

Finally, each row of the data frame dat is matched to a terminal node of each tree:

library(tidyr)  # for pivot_longer()

apply(dat,1, function(z) sapply(treefns, function(fn) fn(z))) |>
  t() |>
  data.frame() |>
  rename_with(function(z) paste0("TREE", gsub("X", "", z))) |>
  cbind(dat) |>
  pivot_longer(cols = starts_with("TREE"),
               names_to = "TREE",
               values_to = "NODE")  |>
  sample_n(10)

##> A tibble: 10 x 7
##>       x1    x2    x3    x4    x5 TREE   NODE
##>    <dbl> <dbl> <dbl> <dbl> <dbl> <chr> <int>
##> 1 0.170  0.690 0.278 0.130 0.307 TREE3    11
##> 2 0.170  0.690 0.278 0.130 0.307 TREE2     8
##> 3 0.370  0.358 0.882 0.629 0.884 TREE2     7
##> 4 0.308  0.625 0.536 0.488 0.331 TREE1     5
##> 5 0.370  0.358 0.882 0.629 0.884 TREE1     4
##> 6 0.552  0.280 0.538 0.349 0.778 TREE3    14
##> 7 0.547  0.359 0.549 0.990 0.208 TREE1     4
##> 8 0.370  0.358 0.882 0.629 0.884 TREE3    15
##> 9 0.547  0.359 0.549 0.990 0.208 TREE2     7
##>10 0.0564 0.398 0.749 0.954 0.827 TREE2     7
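The result above gives the terminal node of each observation. Since the question asks for the number of observations in every node, internal ones included, the same left-child-follows-parent recursion can also return counts directly. What follows is only a sketch under a few assumptions: count_rec is a made-up name, the data contain no missing values, and the toy data are the x1 and x2 columns of dat rounded to three decimals.

```r
# Hypothetical helper: counts for the subtree rooted at row r,
# returned in the same depth-first order as the tree's rows.
count_rec <- function(vars, splits, dat, r = 1L) {
  if (is.na(vars[r])) {
    return(nrow(dat))                        # terminal node
  }
  goes_left <- dat[[vars[r]]] <= splits[r]
  left  <- count_rec(vars, splits, dat[goes_left, , drop = FALSE], r + 1L)
  # length(left) is the size of the left subtree, so the right child
  # sits right after it
  right <- count_rec(vars, splits, dat[!goes_left, , drop = FALSE],
                     r + 1L + length(left))
  c(nrow(dat), left, right)
}

# Tree 1 from df
vars   <- c("x2", "x1", NA, NA, NA)
splits <- c(0.542, 0.126, NA, NA, NA)
dat12 <- data.frame(
  x1 = c(0.308, 0.258, 0.552, 0.056, 0.469, 0.484, 0.812, 0.370, 0.547, 0.170),
  x2 = c(0.625, 0.882, 0.280, 0.398, 0.763, 0.669, 0.205, 0.358, 0.359, 0.690)
)
count_rec(vars, splits, dat12)
# 10 5 1 4 5
```

The counts line up with the rows of tree 1: 10 observations at the root, 5 in its left child, and so on.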

It seems we can do a "rolling split" to get what you are looking for. The logic is as follows:

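  1. Start with a stack that holds a single data frame, dat.
  2. For each (variableName, splitValue) pair: if they are not NA, split the data frame on top of the stack into two sub-data-frames identified by variableName <= splitValue and variableName > splitValue (the former on top of the latter); if they are NA, simply pop the top data frame.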

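Here is the code. Note that this kind of state-dependent computation is hard to vectorize, so it is not something R is good at. If you have many trees and performance becomes a serious problem, I would suggest using Rcpp.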

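eval_node <- function(df, x, v) {
  ## df: data to partition; x: split variable names; v: split values
  ## (x and v in depth-first order, NA marking a terminal node)
  out <- vector("list", length(x))
  stk <- vector("list", sum(is.na(x)))  # stack of data frames
  pos <- 1L                             # index of the top of the stack
  stk[[pos]] <- df
  for (i in seq_along(x)) {
    if (!is.na(x[[i]])) {
      ## internal node: replace the top frame with its two children;
      ## fixed factor levels keep a child even if it has no rows
      subs <- pos + c(0L, 1L)
      cond <- factor(stk[[pos]][[x[[i]]]] <= v[[i]], levels = c(FALSE, TRUE))
      stk[subs] <- split(stk[[pos]], cond)
      names(stk)[subs] <- trimws(paste0(
        names(stk[pos]), ",", x[[i]], c(">", "<="), v[[i]]
      ), "left", ",")
      out[[i]] <- rev(stk[subs])
      pos <- pos + 1L
    } else {
      ## terminal node: record the top frame, then pop it
      out[[i]] <- stk[pos]
      stk[[pos]] <- NULL
      pos <- pos - 1L
    }
  }
  out
}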

Then you can apply the function like this:

library(dplyr)

df %>% group_by(treeNo) %>% mutate(node = eval_node(dat, variableName, splitValue))

Output:

# A tibble: 15 x 4
# Groups:   treeNo [3]
   variableName splitValue treeNo node            
   <chr>             <dbl>  <dbl> <list>          
 1 x2                0.542      1 <named list [2]>
 2 x1                0.126      1 <named list [2]>
 3 NA               NA          1 <named list [1]>
 4 NA               NA          1 <named list [1]>
 5 NA               NA          1 <named list [1]>
 6 x2                0.655      2 <named list [2]>
 7 NA               NA          2 <named list [1]>
 8 NA               NA          2 <named list [1]>
 9 x5                0.418      3 <named list [2]>
10 x4                0.234      3 <named list [2]>
11 NA               NA          3 <named list [1]>
12 NA               NA          3 <named list [1]>
13 x3                0.747      3 <named list [2]>
14 NA               NA          3 <named list [1]>
15 NA               NA          3 <named list [1]>

where node looks like this:

[[1]]
[[1]]$`x2<=0.542`
          x1        x2        x3        x4        x5
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139

[[1]]$`x2>0.542`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[2]]
[[2]]$`x2<=0.542,x1<=0.126`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034

[[2]]$`x2<=0.542,x1>0.126`
         x1        x2        x3        x4        x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[3]]
[[3]]$`x2<=0.542,x1<=0.126`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034


[[4]]
[[4]]$`x2<=0.542,x1>0.126`
         x1        x2        x3        x4        x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[5]]
[[5]]$`x2>0.542`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[6]]
[[6]]$`x2<=0.6547`
          x1        x2        x3        x4        x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139

[[6]]$`x2>0.6547`
          x1        x2        x3        x4        x5
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[7]]
[[7]]$`x2<=0.6547`
          x1        x2        x3        x4        x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139


[[8]]
[[8]]$`x2>0.6547`
          x1        x2        x3        x4        x5
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[9]]
[[9]]$`x5<=0.418`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9  0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859

[[9]]$`x5>0.418`
          x1        x2        x3        x4        x5
2 0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
5 0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
6 0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270


[[10]]
[[10]]$`x5<=0.418,x4<=0.234`
          x1        x2        x3        x4        x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859

[[10]]$`x5<=0.418,x4>0.234`
         x1        x2        x3        x4        x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[11]]
[[11]]$`x5<=0.418,x4<=0.234`
          x1        x2        x3        x4        x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[12]]
[[12]]$`x5<=0.418,x4>0.234`
         x1        x2        x3        x4        x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[13]]
[[13]]$`x5>0.418,x3<=0.747`
         x1        x2        x3        x4        x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318

[[13]]$`x5>0.418,x3>0.747`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270


[[14]]
[[14]]$`x5>0.418,x3<=0.747`
         x1        x2        x3        x4        x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318


[[15]]
[[15]]$`x5>0.418,x3>0.747`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
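If only the number of observations per node is needed, the rolling split can be run on row indices instead of whole data frames. The following is a rough sketch of that variant rather than the code above: count_by_stack is a made-up name, every node is popped when it is visited (instead of only at terminal nodes), the data are assumed to have no missing values, and the toy data are the x3, x4 and x5 columns of dat rounded to three decimals.

```r
# Hypothetical variant of the rolling split that keeps a stack of
# row-index vectors and records one count per tree row.
count_by_stack <- function(x, v, dat) {
  counts <- integer(length(x))
  stk <- list(seq_len(nrow(dat)))      # the last element is the top
  for (i in seq_along(x)) {
    top <- stk[[length(stk)]]          # rows reaching the node in row i
    stk[[length(stk)]] <- NULL         # pop it
    counts[i] <- length(top)
    if (!is.na(x[i])) {
      goes_left <- dat[[x[i]]][top] <= v[i]
      # push the right child first so the left child ends up on top
      stk <- c(stk, list(top[!goes_left]), list(top[goes_left]))
    }
  }
  counts
}

# Tree 3 from df
vars3   <- c("x5", "x4", NA, NA, "x3", NA, NA)
splits3 <- c(0.418, 0.234, NA, NA, 0.747, NA, NA)
dat345 <- data.frame(
  x3 = c(0.536, 0.711, 0.538, 0.749, 0.420, 0.171, 0.770, 0.882, 0.549, 0.278),
  x4 = c(0.488, 0.929, 0.349, 0.954, 0.695, 0.889, 0.180, 0.629, 0.990, 0.130),
  x5 = c(0.331, 0.865, 0.778, 0.827, 0.603, 0.491, 0.780, 0.884, 0.208, 0.307)
)
count_by_stack(vars3, splits3, dat345)
# 10 3 1 2 7 4 3
```

These counts agree with the number of rows printed for nodes 9 to 15 above. Inside a grouped mutate, the call would presumably take the same shape as the eval_node call, e.g. mutate(n = count_by_stack(variableName, splitValue, dat)).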