R partykit计算子树的分类概率

R partykit calculate classification probabilities on sub stree

我训练了一个partykit包ctree分类决策树,我需要计算子树(不仅是叶节点)的分类概率。 因此,例如,如果一棵子树由 3 个具有以下概率的叶节点组成: 叶 1(120 个观察值):0.45 叶子 2(160 个观察值):0.49 叶 3(190 个观察值):0.83

对于这个假设的子树,加权平均概率为 120*0.42 + 160*0.49 + 190*0.83 / (120+160+190) = 0.507

等等我需要遍历ctree对象,递归计算每个节点的所有加权概率。

我有这个代码:

data(airquality)
airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
                 controls = ctree_control(maxsurrogate = 3))
traverse <- function(treenode){
    if(treenode$terminal){
      bas=paste("Current node is terminal node with",treenode$nodeID,'prediction',treenode$prediction)
      print(bas)
      return(0)
    } else {
      bas=paste("Current node",treenode$nodeID,"Split var. ID:",treenode$psplit$variableName,"split value:",treenode$psplit$splitpoint,'prediction',treenode$prediction)
      print(bas)
    }
    traverse(treenode$left)
    traverse(treenode$right)
  }

树上的遍历不适用于 partykit 对象。 另一方面,我有这段代码,它只列出了叶节点的所有概率:

preds.ls <- list(predict(airct , type = "prob"))[1]
pred.probs.df <- unique(as.data.frame((preds.ls[[1]])))

任何将这 2 个片段组合成将遍历 PARTYKIT 对象并计算此加权平均值的代码的建议都值得赞赏

我不熟悉 partykit 但这个简单的函数遍历 ctree 并提取每个内部和终端节点的概率:

   library(party)

    set.seed(100)
    dt <- ctree(factor(mpg > 20)~., data = mtcars,
                control = ctree_control(minsplit=2, minbucket=1, mincriterion=0))

    traverse <- function(node) {
      if (node$terminal) {
        return(node$prediction[2])
      }
      return(c(node$prediction[2],
               traverse(node$left), traverse(node$right)))
    }

调用该函数会产生以下概率向量:

> traverse(dt@tree)
[1] 0.4375000 1.0000000 0.1428571 0.4285714 0.7500000 0.0000000 0.0000000

最左边的值是通过以下验证的人口值:

> mean(mtcars$mpg > 20)
[1] 0.4375

其余值将按从左到右的顺序排列。您可以看到 1 和 0 在预期的位置排列。