修改ctree()中的终端节点,partykit包
Modifying terminal node in ctree(), partykit package
我有一个因变量要通过决策树class确定。它由三类频率组成:738 (19%)、426 (15%) 和 1800 (66%)。正如您想象的那样,预测类别始终是第三个类别,但树的目的是描述性的,因此实际上并不重要。
问题是,当通过 ctree()
函数(包 partykit
)绘制树时,终端节点显示直方图,显示三个 classes 出现的概率。我需要修改此输出:我想获得终端节点内每个 class 相对于 class' 绝对频率的出现比例。
比如class1的738个参与者中有多少百分比属于某个终端节点?每个终端节点将显示构成因变量的所有三个 class 的值。
下面是树图,默认情况下会报告终端节点内每个 class 的流行程度。
您始终可以定义自己的面板函数来绘制进入每个终端面板的内容 window。如果您对 grid
图形了解一点,并且查看当前终端面板功能的定义方式,您就会明白它是如何工作的。
partykit
包中的 node_terminal()
一个面板功能应该可以满足您的需求(对旧的 party
包进行了很大改进的重新实现)。但是,由于 ctree()
不会将其预测存储在每个终端节点中,因此 node_terminal()
函数目前无法开箱即用。我将尝试改进未来版本中的实现,以便于实现这一点。下面是一个有点复杂的例子,我希望它能做你想做的事。
首先,我们使用 iris
数据拟合分类树(对于一个简单的可重现示例):
library("partykit")
(ct <- ctree(Species ~ ., data = iris))
## Model formula:
## Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
##
## Fitted party:
## [1] root
## | [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
## | [3] Petal.Length > 1.9
## | | [4] Petal.Width <= 1.7
## | | | [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
## | | | [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
## | | [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
##
## Number of inner nodes: 3
## Number of terminal nodes: 4
然后我们计算每个终端节点的 table 预测概率:
(pred <- aggregate(predict(ct, type = "prob"),
list(predict(ct, type = "node")), FUN = mean))
## Group.1 setosa versicolor virginica
## 1 2 1 0.00000000 0.00000000
## 2 5 0 0.97826087 0.02173913
## 3 6 0 0.50000000 0.50000000
## 4 7 0 0.02173913 0.97826087
然后是不太明显的部分:我们希望将这些预测概率包含在树本身的终端节点中。为此,我们将递归节点结构强制转换为平面列表,插入预测(适当格式化),并将列表转换回节点结构:
ct_node <- as.list(ct$node)
for(i in 1:nrow(pred)) {
ct_node[[pred[i,1]]]$info$prediction <- paste(
format(names(pred)[-1]),
format(round(pred[i, -1], digits = 3), nsmall = 3)
)
}
ct$node <- as.partynode(ct_node)
然后,我们可以使用 node_terminal
面板函数轻松绘制树的图片并插入我们预先格式化的预测:
plot(ct, terminal_panel = node_terminal, tp_args = list(
FUN = function(node) c("Predictions", node$prediction)))
编辑:list
和 party
之间的来回强制实际上已经在包中实现了……我只是忘了它 ;-) 如果你这样做
st <- as.simpleparty(ct)
然后结果 party
在每个节点中都有关于预测等的更详细信息。例如, $distribution
然后包含每个响应级别的绝对频率。这可以很容易地像以前一样格式化
pred <- function(i) {
tab <- i$distribution
tab <- round(prop.table(tab), 3)
tab <- paste0(names(tab), ":", format(tab, nsmall = 3))
c("Predictions", tab)
}
这可以传递给 node_terminal
以创建上面的情节。如果您希望所有终端节点都显示在底行,您可能需要将 drop = FALSE
更改为 drop = TRUE
。
plot(st, terminal_panel = node_terminal, tp_args = list(FUN = pred))
我有一个因变量要通过决策树class确定。它由三类频率组成:738 (19%)、426 (15%) 和 1800 (66%)。正如您想象的那样,预测类别始终是第三个类别,但树的目的是描述性的,因此实际上并不重要。
问题是,当通过 ctree()
函数(包 partykit
)绘制树时,终端节点显示直方图,显示三个 classes 出现的概率。我需要修改此输出:我想获得终端节点内每个 class 相对于 class' 绝对频率的出现比例。
比如class1的738个参与者中有多少百分比属于某个终端节点?每个终端节点将显示构成因变量的所有三个 class 的值。
下面是树图,默认情况下会报告终端节点内每个 class 的流行程度。
您始终可以定义自己的面板函数来绘制进入每个终端面板的内容 window。如果您对 grid
图形了解一点,并且查看当前终端面板功能的定义方式,您就会明白它是如何工作的。
partykit
包中的 node_terminal()
一个面板功能应该可以满足您的需求(对旧的 party
包进行了很大改进的重新实现)。但是,由于 ctree()
不会将其预测存储在每个终端节点中,因此 node_terminal()
函数目前无法开箱即用。我将尝试改进未来版本中的实现,以便于实现这一点。下面是一个有点复杂的例子,我希望它能做你想做的事。
首先,我们使用 iris
数据拟合分类树(对于一个简单的可重现示例):
library("partykit")
(ct <- ctree(Species ~ ., data = iris))
## Model formula:
## Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
##
## Fitted party:
## [1] root
## | [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
## | [3] Petal.Length > 1.9
## | | [4] Petal.Width <= 1.7
## | | | [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
## | | | [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
## | | [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
##
## Number of inner nodes: 3
## Number of terminal nodes: 4
然后我们计算每个终端节点的 table 预测概率:
(pred <- aggregate(predict(ct, type = "prob"),
list(predict(ct, type = "node")), FUN = mean))
## Group.1 setosa versicolor virginica
## 1 2 1 0.00000000 0.00000000
## 2 5 0 0.97826087 0.02173913
## 3 6 0 0.50000000 0.50000000
## 4 7 0 0.02173913 0.97826087
然后是不太明显的部分:我们希望将这些预测概率包含在树本身的终端节点中。为此,我们将递归节点结构强制转换为平面列表,插入预测(适当格式化),并将列表转换回节点结构:
ct_node <- as.list(ct$node)
for(i in 1:nrow(pred)) {
ct_node[[pred[i,1]]]$info$prediction <- paste(
format(names(pred)[-1]),
format(round(pred[i, -1], digits = 3), nsmall = 3)
)
}
ct$node <- as.partynode(ct_node)
然后,我们可以使用 node_terminal
面板函数轻松绘制树的图片并插入我们预先格式化的预测:
plot(ct, terminal_panel = node_terminal, tp_args = list(
FUN = function(node) c("Predictions", node$prediction)))
编辑:list
和 party
之间的来回强制实际上已经在包中实现了……我只是忘了它 ;-) 如果你这样做
st <- as.simpleparty(ct)
然后结果 party
在每个节点中都有关于预测等的更详细信息。例如, $distribution
然后包含每个响应级别的绝对频率。这可以很容易地像以前一样格式化
pred <- function(i) {
tab <- i$distribution
tab <- round(prop.table(tab), 3)
tab <- paste0(names(tab), ":", format(tab, nsmall = 3))
c("Predictions", tab)
}
这可以传递给 node_terminal
以创建上面的情节。如果您希望所有终端节点都显示在底行,您可能需要将 drop = FALSE
更改为 drop = TRUE
。
plot(st, terminal_panel = node_terminal, tp_args = list(FUN = pred))