在 R 的决策树中获取遵循特定规则的元组
To get the tuples which follow a particular rule in a decision tree in R
我正在使用 R 中的 rpart 创建决策树。我还可以使用 path.rpart() 函数打印出决策树生成的规则。
对于空气质量数据,我将规则输出为
$`8`
[1] "root" "Temp< 82.5" "Wind>=7.15" "Solar.R< 79.5"
$`18`
[1] "root" "Temp< 82.5" "Wind>=7.15" "Solar.R>=79.5"
[5] "Temp< 77.5"
$`19`
[1] "root" "Temp< 82.5" "Wind>=7.15" "Solar.R>=79.5"
5] "Temp>=77.5"
$`5`
[1] "root" "Temp< 82.5" "Wind< 7.15"
等等。
有没有一种方法可以让我编写代码,将这些限制条件施加到我最初的 table 空气质量上,以获得遵循这些规则的行
这相当于
airquality[which(airquality$Temp<82.5 & airquality$Wind>=7.15 & Solar.R<79.5)]
第一条规则。
非常感谢任何帮助。提前致谢。
rules = rpart(airquality)
table(rules$where)
airquality[rules$where==6,]
你会在没有编码规则的情况下给你分割数据框吗?我不确定这是否是您要找的。
path.rpart 给出了一个很好的概述,但 MrFlick 已经编写了一些代码来显示哪些观察结果落在特定节点中。看here.
这只看rpart树。要查看预测值落在哪个节点,请查看 this post。
参见我包含的示例代码。该功能来自第一个答案。第二个答案的最后一部分。
library(rpart)
# split kyphosis into 2 for example
train <- kyphosis[1:60, ]
test <- kyphosis[-(1:60), ]
fit <- rpart(Kyphosis ~ Age + Number + Start, data = train)
# Show nodes
print(fit)
# function to show observations that fall in a node
#
subset_rpart <- function (tree, df, nodes) {
if (!inherits(tree, "rpart"))
stop("Not a legitimate \"rpart\" object")
stopifnot(nrow(df)==length(tree$where))
frame <- tree$frame
n <- row.names(frame)
node <- as.numeric(n)
if (missing(nodes)) {
xy <- rpart:::rpartco(tree)
i <- identify(xy, n = 1L, plot = FALSE)
if(i> 0L) {
return( df[tree$where==i, ] )
} else {
return(df[0,])
}
}
else {
if (length(nodes <- rpart:::node.match(nodes, node)) == 0L)
return(df[0,])
return ( df[tree$where %in% as.numeric(nodes), ] )
}
}
subset_rpart(fit, train, 7)
# Find the nodes in which the test observations fall
#
nodes_fit <- fit
nodes_fit$frame$yval <- as.numeric(rownames(nodes_fit$frame))
testnodes <- predict(nodes_fit, test, type="vector")
print(testnodes)
我正在使用 R 中的 rpart 创建决策树。我还可以使用 path.rpart() 函数打印出决策树生成的规则。
对于空气质量数据,我将规则输出为
$`8`
[1] "root" "Temp< 82.5" "Wind>=7.15" "Solar.R< 79.5"
$`18`
[1] "root" "Temp< 82.5" "Wind>=7.15" "Solar.R>=79.5"
[5] "Temp< 77.5"
$`19`
[1] "root" "Temp< 82.5" "Wind>=7.15" "Solar.R>=79.5"
5] "Temp>=77.5"
$`5`
[1] "root" "Temp< 82.5" "Wind< 7.15"
等等。
有没有一种方法可以让我编写代码,将这些限制条件施加到我最初的 table 空气质量上,以获得遵循这些规则的行 这相当于
airquality[which(airquality$Temp<82.5 & airquality$Wind>=7.15 & Solar.R<79.5)]
第一条规则。
非常感谢任何帮助。提前致谢。
rules = rpart(airquality)
table(rules$where)
airquality[rules$where==6,]
你会在没有编码规则的情况下给你分割数据框吗?我不确定这是否是您要找的。
path.rpart 给出了一个很好的概述,但 MrFlick 已经编写了一些代码来显示哪些观察结果落在特定节点中。看here.
这只看rpart树。要查看预测值落在哪个节点,请查看 this post。
参见我包含的示例代码。该功能来自第一个答案。第二个答案的最后一部分。
library(rpart)
# split kyphosis into 2 for example
train <- kyphosis[1:60, ]
test <- kyphosis[-(1:60), ]
fit <- rpart(Kyphosis ~ Age + Number + Start, data = train)
# Show nodes
print(fit)
# function to show observations that fall in a node
#
subset_rpart <- function (tree, df, nodes) {
if (!inherits(tree, "rpart"))
stop("Not a legitimate \"rpart\" object")
stopifnot(nrow(df)==length(tree$where))
frame <- tree$frame
n <- row.names(frame)
node <- as.numeric(n)
if (missing(nodes)) {
xy <- rpart:::rpartco(tree)
i <- identify(xy, n = 1L, plot = FALSE)
if(i> 0L) {
return( df[tree$where==i, ] )
} else {
return(df[0,])
}
}
else {
if (length(nodes <- rpart:::node.match(nodes, node)) == 0L)
return(df[0,])
return ( df[tree$where %in% as.numeric(nodes), ] )
}
}
subset_rpart(fit, train, 7)
# Find the nodes in which the test observations fall
#
nodes_fit <- fit
nodes_fit$frame$yval <- as.numeric(rownames(nodes_fit$frame))
testnodes <- predict(nodes_fit, test, type="vector")
print(testnodes)