Calculate number of observations in each node in a decision tree in R?
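Similar questions have been asked before, for example here and here, but none of the answers can be applied to my problem. I am trying to identify and count the observations in each node of a decision tree. However, the tree structure comes from a data frame of trees that I build from the BART package: I extract the tree information from the BART package and convert it into a data frame like the one shown below (i.e. df), and I need to work with this data frame structure. Aside: I believe the order in which the trees are laid out in my data frame is called 'depth first'.
For example, my data frame of trees looks like this: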
library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))
Visually, these trees look like:
The trees are laid out left-first when traversing down df. Additionally, all splits are binary, so every split node has exactly 2 children.
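To make the depth-first, left-first layout concrete, here is a small sketch that recovers the depth of every row of a single tree from the NA pattern alone. node_depth is a hypothetical helper written only for this illustration, and it assumes that every split has exactly two children, as described above.

```r
# Hypothetical helper: depth of each node of ONE tree stored depth-first.
# is_leaf is TRUE where variableName is NA (terminal node).
node_depth <- function(is_leaf) {
  depth <- integer(length(is_leaf))
  pending <- 0L                          # stack of depths still to be visited
  for (i in seq_along(is_leaf)) {
    d <- pending[length(pending)]        # depth of the current row
    pending <- pending[-length(pending)] # pop it
    depth[i] <- d
    if (!is_leaf[i]) {
      # a split node is followed by its left subtree, then its right subtree
      pending <- c(pending, d + 1L, d + 1L)
    }
  }
  depth
}

node_depth(is.na(c("x2", "x1", NA, NA, NA)))             # tree 1: 0 1 2 2 1
node_depth(is.na(c("x5", "x4", NA, NA, "x3", NA, NA)))   # tree 3: 0 1 2 2 1 2 2
```

The depths show that in tree 1 the split on x1 is the left child of the root, and the last row is the root's right child.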
So, if we create some data that looks like this:
set.seed(100)
dat <- data.frame( x1 = runif(10),
x2 = runif(10),
x3 = runif(10),
x4 = runif(10),
x5 = runif(10)
)
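I am trying to find out which observations of dat fall into which node.
Attempted answer:
This isn't very helpful, but for clarity (as I'm still trying to work this out), hard-coding it for tree number three would look like this: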
lists <- df %>% group_by(treeNo) %>% group_split()
tree<- lists[[3]]
namesDf <- names(dat[grepl(tree[1, ]$variableName, names(dat))])
dataLeft <- dat[dat[, namesDf] <= tree[1,]$splitValue, ]
dataRight <- dat[dat[, namesDf] > tree[1,]$splitValue, ]
namesDf <- names(dat[grepl(tree[2, ]$variableName, names(dat))])
dataLeft1 <- dataLeft[dataLeft[, namesDf] <= tree[2,]$splitValue, ]
dataRight1 <- dataLeft[dataLeft[, namesDf] > tree[2,]$splitValue, ]
namesDf <- names(dat[grepl(tree[5, ]$variableName, names(dat))])
dataLeft2 <- dataRight[dataRight[, namesDf] <= tree[5,]$splitValue, ]
dataRight2 <- dataRight[dataRight[, namesDf] > tree[5,]$splitValue, ]
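I've been trying to turn this into a loop, but it's proving challenging to work out, and I (obviously) can't hard-code it for every tree. Any suggestions as to how I could solve this?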
There is a lot of room for optimization, but here is my attempt. Your trees seem to be built depth-first, with the left child always directly following its parent node:
library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))
Given the data to match:
set.seed(100)
dat <- data.frame( x1 = runif(10),
x2 = runif(10),
x3 = runif(10),
x4 = runif(10),
x5 = runif(10)
)
dat
##> x1 x2 x3 x4 x5
##>1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
##>2 0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
##>3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
##>4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
##>5 0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
##>6 0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
##>7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
##>8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
##>9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
##>10 0.17026205 0.6902905 0.2777238 0.1302889 0.3070859
makeTree is a higher-order function: it returns a function which in turn maps a row of values to a node:
makeTree <- function(dat, r = 1) {
  ## The argument dat is a data frame representation
  ## of a single tree, as in the example.
  ## Returns a list of two elements: size and fn.
  ##  - size is the number of rows taken up by the
  ##    node and its descendants.
  ##  - fn is a function of one argument (either a list or
  ##    a row of a data frame) that returns the index of the
  ##    node matching the argument. More precisely, the value
  ##    of the column id in dat.
  stopifnot(r <= nrow(dat))
  vname <- pull(dat, variableName)[r]
  splitVal <- pull(dat, splitValue)[r]
  if (is.na(vname)) {
    ## terminal node
    ## print(sprintf("terminal node: %i", r))
    res <- list(size = 1, # offset to access right node
                fn = function(z) {
                  pull(dat, "id")[r]
                })
    return(res)
  } else {
    ## print(sprintf("node: %i, varName: %s, splitVal: %f", r, vname, splitVal))
    ## compute the left and right functions
    ## note that the tree is traversed depth-first
    fnleft <- makeTree(dat, r + 1) # the left child is always positioned right after the caller
    fnright <- makeTree(dat, r + fnleft$size + 1)
    return(list(size = fnleft$size + fnright$size + 1,
                fn = function(z) {
                  if (z[vname] <= splitVal)
                    fnleft$fn(z)
                  else
                    fnright$fn(z)
                }))
  }
}
Now makeTree is applied to each tree to produce a list of matching functions:
treefns <- df |>
  mutate(id = row_number()) |>
  group_by(treeNo) |>
  group_split() |>
  purrr::map(makeTree) |>
  purrr::map("fn")
Finally, each row of the data frame dat is matched to a terminal node of each tree:
apply(dat, 1, function(z) sapply(treefns, function(fn) fn(z))) |>
  t() |>
  data.frame() |>
  rename_with(function(z) paste0("TREE", gsub("X", "", z))) |>
  cbind(dat) |>
  tidyr::pivot_longer(cols = starts_with("TREE"), # pivot_longer() comes from tidyr, not dplyr
                      names_to = "TREE",
                      values_to = "NODE") |>
  sample_n(10)
##> A tibble: 10 x 7
##> x1 x2 x3 x4 x5 TREE NODE
##> <dbl> <dbl> <dbl> <dbl> <dbl> <chr> <int>
##> 1 0.170 0.690 0.278 0.130 0.307 TREE3 11
##> 2 0.170 0.690 0.278 0.130 0.307 TREE2 8
##> 3 0.370 0.358 0.882 0.629 0.884 TREE2 7
##> 4 0.308 0.625 0.536 0.488 0.331 TREE1 5
##> 5 0.370 0.358 0.882 0.629 0.884 TREE1 4
##> 6 0.552 0.280 0.538 0.349 0.778 TREE3 14
##> 7 0.547 0.359 0.549 0.990 0.208 TREE1 4
##> 8 0.370 0.358 0.882 0.629 0.884 TREE3 15
##> 9 0.547 0.359 0.549 0.990 0.208 TREE2 7
##>10 0.0564 0.398 0.749 0.954 0.827 TREE2 7
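Since the result above holds the terminal node of each observation, the count per node is just a tabulation of the NODE column within each tree. A minimal sketch in base R, using only the ten sampled rows printed above as input (so the counts describe that sample, not the full data); the names assignments and counts are made up for this illustration.

```r
# TREE / NODE pairs copied from the ten sampled rows shown above
assignments <- data.frame(
  TREE = c("TREE3", "TREE2", "TREE2", "TREE1", "TREE1",
           "TREE3", "TREE1", "TREE3", "TREE2", "TREE2"),
  NODE = c(11L, 8L, 7L, 5L, 4L, 14L, 4L, 15L, 7L, 7L)
)

# one row per (tree, node) combination that occurs, with its frequency in n
assignments$n <- 1L
counts <- aggregate(n ~ TREE + NODE, data = assignments, FUN = sum)
counts
```

Run on the full, unsampled result, the same tabulation would give the number of observations in every terminal node. The count of a split node is then the sum over the terminal nodes below it.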
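It seems that we can do a "rolling split" to get what you are looking for. The logic is as follows:
- Start with a stack that contains a single data frame, dat.
- For each pair of variableName and splitValue: if they are not NA, split the top data frame on the stack into two sub-data frames identified by variableName <= splitValue and variableName > splitValue (the former on top of the latter); if they are NA, simply pop the top data frame.
Here is the code. Note that this kind of state-dependent computation is hard to vectorize, so it is not something R is good at. If you have a lot of trees and performance becomes a serious issue, I would suggest rewriting the code below with Rcpp.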
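eval_node <- function(df, x, v) {
  out <- vector("list", length(x))
  stk <- vector("list", sum(is.na(x)))
  pos <- 1L
  stk[[pos]] <- df
  for (i in seq_along(x)) {
    if (!is.na(x[[i]])) {
      # split node: replace the top data frame by its two children
      subs <- pos + c(0L, 1L)
      top <- stk[[pos]]
      # fix both levels so that split() always returns two pieces
      # (">" first, "<=" second), even if one side is empty
      goes_left <- factor(top[[x[[i]]]] <= v[[i]], levels = c(FALSE, TRUE))
      stk[subs] <- split(top, goes_left)
      names(stk)[subs] <- trimws(paste0(
        names(stk[pos]), ",", x[[i]], c(">", "<="), v[[i]]
      ), "left", ",")
      out[[i]] <- rev(stk[subs])
      pos <- pos + 1L
    } else {
      # terminal node: record the top data frame and pop it
      out[[i]] <- stk[pos]
      stk[[pos]] <- NULL
      pos <- pos - 1L
    }
  }
  out
}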
Then you can apply the function like this.
library(dplyr)
df %>% group_by(treeNo) %>% mutate(node = eval_node(dat, variableName, splitValue))
Output
# A tibble: 15 x 4
# Groups: treeNo [3]
variableName splitValue treeNo node
<chr> <dbl> <dbl> <list>
1 x2 0.542 1 <named list [2]>
2 x1 0.126 1 <named list [2]>
3 NA NA 1 <named list [1]>
4 NA NA 1 <named list [1]>
5 NA NA 1 <named list [1]>
6 x2 0.655 2 <named list [2]>
7 NA NA 2 <named list [1]>
8 NA NA 2 <named list [1]>
9 x5 0.418 3 <named list [2]>
10 x4 0.234 3 <named list [2]>
11 NA NA 3 <named list [1]>
12 NA NA 3 <named list [1]>
13 x3 0.747 3 <named list [2]>
14 NA NA 3 <named list [1]>
15 NA NA 3 <named list [1]>
where node looks like this
[[1]]
[[1]]$`x2<=0.542`
x1 x2 x3 x4 x5
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
[[1]]$`x2>0.542`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[2]]
[[2]]$`x2<=0.542,x1<=0.126`
x1 x2 x3 x4 x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
[[2]]$`x2<=0.542,x1>0.126`
x1 x2 x3 x4 x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
[[3]]
[[3]]$`x2<=0.542,x1<=0.126`
x1 x2 x3 x4 x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
[[4]]
[[4]]$`x2<=0.542,x1>0.126`
x1 x2 x3 x4 x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
[[5]]
[[5]]$`x2>0.542`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[6]]
[[6]]$`x2<=0.6547`
x1 x2 x3 x4 x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
[[6]]$`x2>0.6547`
x1 x2 x3 x4 x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[7]]
[[7]]$`x2<=0.6547`
x1 x2 x3 x4 x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
[[8]]
[[8]]$`x2>0.6547`
x1 x2 x3 x4 x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[9]]
[[9]]$`x5<=0.418`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[9]]$`x5>0.418`
x1 x2 x3 x4 x5
2 0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
5 0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
6 0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
[[10]]
[[10]]$`x5<=0.418,x4<=0.234`
x1 x2 x3 x4 x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[10]]$`x5<=0.418,x4>0.234`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
[[11]]
[[11]]$`x5<=0.418,x4<=0.234`
x1 x2 x3 x4 x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[12]]
[[12]]$`x5<=0.418,x4>0.234`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
[[13]]
[[13]]$`x5>0.418,x3<=0.747`
x1 x2 x3 x4 x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
[[13]]$`x5>0.418,x3>0.747`
x1 x2 x3 x4 x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
[[14]]
[[14]]$`x5>0.418,x3<=0.747`
x1 x2 x3 x4 x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
[[15]]
[[15]]$`x5>0.418,x3>0.747`
x1 x2 x3 x4 x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
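If only the number of observations per node is needed, the data frames themselves do not have to be carried around: the same stack logic can be run on row indices. A minimal sketch under the same assumptions (binary splits, depth-first and left-first order, values <= the split value go left); count_nodes is a hypothetical helper, not part of the answer above, and the two columns below are copied from the printed dat.

```r
# Hypothetical helper: number of observations reaching every node of ONE tree.
# vars / vals are that tree's variableName / splitValue columns
# (NA marks a terminal node).
count_nodes <- function(data, vars, vals) {
  n_obs <- integer(length(vars))
  stack <- list(seq_len(nrow(data)))     # row indices waiting for their node
  for (i in seq_along(vars)) {
    stopifnot(length(stack) > 0)         # fails on a malformed tree
    idx <- stack[[length(stack)]]        # rows that reach node i
    stack[[length(stack)]] <- NULL       # pop
    n_obs[i] <- length(idx)
    if (!is.na(vars[i])) {
      goes_left <- data[[vars[i]]][idx] <= vals[i]
      # push right first, so that the left child is handled next
      stack <- c(stack, list(idx[!goes_left]), list(idx[goes_left]))
    }
  }
  n_obs
}

# x1 and x2 as printed for dat above
dat_small <- data.frame(
  x1 = c(0.30776611, 0.25767250, 0.55232243, 0.05638315, 0.46854928,
         0.48377074, 0.81240262, 0.37032054, 0.54655860, 0.17026205),
  x2 = c(0.6249965, 0.8821655, 0.2803538, 0.3984879, 0.7625511,
         0.6690217, 0.2046122, 0.3575249, 0.3594751, 0.6902905)
)

tree1 <- count_nodes(dat_small, c("x2", "x1", NA, NA, NA), c(0.542, 0.126, NA, NA, NA))
tree2 <- count_nodes(dat_small, c("x2", NA, NA), c(0.6547, NA, NA))
tree1   # 10 5 1 4 5
tree2   # 10 6 4
```

These counts agree with the row counts of the data frames listed in node above. Because the function returns one value per node in the original row order, it should also work inside the grouped mutate() used earlier, for example mutate(nObs = count_nodes(dat, variableName, splitValue)), where nObs is just an illustrative column name.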