在 R 的决策树中获取遵循特定规则的元组

To get the tuples which follow a particular rule in a decision tree in R

我正在使用 R 中的 rpart 创建决策树。我还可以使用 path.rpart() 函数打印出决策树生成的规则。

对于空气质量数据,我将规则输出为

$`8`
[1] "root"          "Temp< 82.5"    "Wind>=7.15"    "Solar.R< 79.5"

$`18`
[1] "root"          "Temp< 82.5"    "Wind>=7.15"    "Solar.R>=79.5"
[5] "Temp< 77.5"   

$`19`
[1] "root"          "Temp< 82.5"    "Wind>=7.15"    "Solar.R>=79.5"
5] "Temp>=77.5"   

$`5`
[1] "root"       "Temp< 82.5" "Wind< 7.15"

等等。

有没有一种方法可以让我编写代码,将这些限制条件施加到我最初的 table 空气质量上,以获得遵循这些规则的行 这相当于

airquality[which(airquality$Temp<82.5 & airquality$Wind>=7.15 & Solar.R<79.5)]

第一条规则。

非常感谢任何帮助。提前致谢。

rules = rpart(airquality)    
table(rules$where)
airquality[rules$where==6,]

你会在没有编码规则的情况下给你分割数据框吗?我不确定这是否是您要找的。

path.rpart 给出了一个很好的概述,但 MrFlick 已经编写了一些代码来显示哪些观察结果落在特定节点中。看here.

这只看rpart树。要查看预测值落在哪个节点,请查看 this post

参见我包含的示例代码。该功能来自第一个答案。第二个答案的最后一部分。

library(rpart)

# split kyphosis into 2 for example
train <- kyphosis[1:60, ]
test <- kyphosis[-(1:60), ]
fit <- rpart(Kyphosis ~ Age + Number + Start, data = train)

# Show nodes
print(fit)

# function to show observations that fall in a node
# 
subset_rpart <- function (tree, df, nodes) {
      if (!inherits(tree, "rpart")) 
            stop("Not a legitimate \"rpart\" object")
      stopifnot(nrow(df)==length(tree$where))
      frame <- tree$frame
      n <- row.names(frame)
      node <- as.numeric(n)

      if (missing(nodes)) {
            xy <- rpart:::rpartco(tree)
            i <- identify(xy, n = 1L, plot = FALSE)
            if(i> 0L) {
                  return( df[tree$where==i, ] )
            } else {
                  return(df[0,])
            }
      }
      else {
            if (length(nodes <- rpart:::node.match(nodes, node)) == 0L) 
                  return(df[0,])
            return ( df[tree$where %in% as.numeric(nodes), ] )
      }
}

subset_rpart(fit, train, 7)


# Find the nodes in which the test observations fall 
# 
nodes_fit <- fit
nodes_fit$frame$yval <- as.numeric(rownames(nodes_fit$frame))
testnodes <- predict(nodes_fit, test, type="vector")
print(testnodes)