从决策树进行预测的高效算法(使用 R)

Efficient algorithm for predicting from a decision tree (using R)

我正在修改 Brieman 的随机森林程序(我不知道 C/C++),所以我在 R 中从头开始编写了我自己的 RF 变体。两者之间的区别我的程序和标准程序基本上只是如何计算分割点和终端节点中的值——一旦我在森林中有一棵树,它可以被认为与典型 RF 算法中的树非常相似。

我的问题是它的预测速度很慢,而且我很难想出让它更快的方法。

链接了一个测试树对象here, and some test data is linked here。可以直接下载,也可以在安装了repmis的情况下在下方加载。他们被称为 testtreesampx.

library(repmis)
testtree <- source_DropboxData(file = "testtree", key = "sfbmojc394cnae8")
sampx <- source_DropboxData(file = "sampx", key = "r9imf317hpflpsx")

编辑:不知何故,我仍然没有时间真正学习如何很好地使用 github。我已经将需要的文件上传到存储库 here -- 抱歉,我现在不知道如何获得永久链接...

它看起来像这样(使用我编写的绘图函数):

以下是有关对象结构的一些信息:

1> summary(testtree)
         Length Class      Mode   
nodes       7   -none-     list   
minsplit    1   -none-     numeric
X          29   data.frame list   
y        6719   -none-     numeric
weights  6719   -none-     numeric
oob      2158   -none-     numeric
1> summary(testtree$nodes)
     Length Class  Mode
[1,] 4      -none- list
[2,] 8      -none- list
[3,] 8      -none- list
[4,] 7      -none- list
[5,] 7      -none- list
[6,] 7      -none- list
[7,] 7      -none- list
1> summary(testtree$nodes[[1]])
         Length Class  Mode   
y        6719   -none- numeric
output         1   -none- numeric
Terminal    1   -none- logical
children    2   -none- numeric
1> testtree$nodes[[1]][2:4]
$output
[1] 40.66925

$Terminal
[1] FALSE

$children
[1] 2 3

1> summary(testtree$nodes[[2]])
           Length Class  Mode     
y          2182   -none- numeric  
parent        1   -none- numeric  
splitvar      1   -none- character
splitpoint    1   -none- numeric  
handedness    1   -none- character
children      2   -none- numeric  
output        1   -none- numeric  
Terminal      1   -none- logical  
1> testtree$nodes[[2]][2:8]
$parent
[1] 1

$splitvar
[1] "bizrev_allHH"

$splitpoint
    25% 
788.875 

$handedness
[1] "Left"

$children
[1] 4 5

$output
[1] 287.0085

$Terminal
[1] FALSE

output 是该节点的 return 值——我希望其他一切都是不言自明的。

我写的预测函数可以用,但是太慢了。基本上就是"walks down the tree",观察观察:

predict.NT = function(tree.obj, newdata=NULL){
    if (is.null(newdata)){X = tree.obj$X} else {X = newdata}
    tree = tree.obj$nodes
    if (length(tree)==1){#Return the mean for a stump
        return(rep(tree[[1]]$output,length(X)))
    }
    pred = apply(X = newdata, 1, godowntree, nn=1, tree=tree)
    return(pred)
}

godowntree = function(x, tree, nn = 1){
    while (tree[[nn]]$Terminal == FALSE){
        fb = tree[[nn]]$children[1]
        sv = tree[[fb]]$splitvar
        sp = tree[[fb]]$splitpoint
        if (class(sp)=='factor'){
            if (as.character(x[names(x) == sv]) == sp){
                nn<-fb
            } else{
                nn<-fb+1
            }
        } else {
            if (as.character(x[names(x) == sv]) < sp){
                nn<-fb
            } else{
                nn<-fb+1
            }
        }
    }
    return(tree[[nn]]$output)
}

问题是它真的很慢(当你考虑到非样本树更大,而且我需要这样做很多很多次),即使是一个简单的树:

library(microbenchmark)
microbenchmark(predict.NT(testtree,sampx))
Unit: milliseconds
                        expr      min       lq     mean   median       uq
 predict.NT(testtree, sampx) 16.19845 16.36351 17.37022 16.54396 17.07274
     max neval
 40.4691   100

我今天从某人那里得到了一个想法,我可以编写一个函数工厂类型的函数(即:生成闭包的函数,我刚刚学习)将我的树分解成一堆嵌套的 if/else 语句。然后我可以通过它发送数据,这可能比一遍又一遍地从树中提取数据更快。我还没有写函数函数生成函数,但我手写了我从中得到的那种输出,并测试了:

predictif = function(x){
    if (x[names(x) == 'bizrev_allHH'] < 788.875){
        if (x[names(x) == 'male_head'] <.872){
            return(548)
        } else {
            return(165)
        }
    } else {
        if (x[names(x) == 'nondurable_exp_mo'] < 4190.965){
            return(-283)
        }else{
            return(-11.4)
        }
    }
}
predictif.NT = function(tree.obj, newdata=NULL){
    if (is.null(newdata)){X = tree.obj$X} else {X = newdata}
    tree = tree.obj$nodes
    if (length(tree)==1){#Return the mean for a stump
        return(rep(tree[[1]]$output,length(X)))
    }
    pred = apply(X = newdata, 1, predictif)
    return(pred)
}

microbenchmark(predictif.NT(testtree,sampx))
Unit: milliseconds
                          expr      min       lq     mean   median       uq
 predictif.CT(testtree, sampx) 12.77701 12.97551 14.21417 13.18939 13.67667
      max neval
 30.48373   100

快一点,但不多!

如果有任何加快速度的想法,我将不胜感激!或者,如果答案是 "you really can't get this much faster without converting it to C/C++",那也将是有价值的信息(特别是如果您向我提供了一些关于为什么会这样的信息)。

虽然我当然很喜欢 R 中的答案,但伪代码中的答案也很有帮助。

谢谢!

加速函数的秘诀是向量化。不要对每一行单独执行所有操作,而是一次对所有行执行它们。

让我们重新考虑一下您的 predictif 函数

predictif = function(x){
    if (x[names(x) == 'bizrev_allHH'] < 788.875){
        if (x[names(x) == 'male_head'] <.872){
            return(548)
        } else {
            return(165)
        }
    } else {
        if (x[names(x) == 'nondurable_exp_mo'] < 4190.965){
            return(-283)
        }else{
            return(-11.4)
        }
    }
}

这是一种缓慢的方法,因为它将所有这些操作应用于每个单独的实例。函数调用、if 语句,尤其是像 names(x) == 'bizrev_allHH' 这样的操作,在您为每个实例执行这些操作时都会增加一些开销。

相比之下,简单地比较两个数字非常快!因此,改为编写上面的矢量化版本。

predictif_fast <- function(newdata) {
  n1 <- newdata$bizrev_allHH < 788.875
  n2 <- newdata$male_head < .872
  n3 <- newdata$nondurable_exp_mo < 4190.965

  ifelse(n1, ifelse(n2, 548.55893, 165.15537),
             ifelse(n3, -283.35145, -11.40185))
}

注意,这一点非常重要,这个函数没有被传递一个实例。它旨在传递你的整个新数据。这是可行的,因为 <ifelse 操作都是向量化的:当给定一个向量时,它们 return 一个向量。

让我们比较一下你的函数和这个新函数:

> microbenchmark(predictif.NT(testtree, sampx),
                 predictif_fast(sampx))
Unit: microseconds
                          expr       min         lq     mean    median         uq
 predictif.NT(testtree, sampx) 12106.419 13144.2390 14684.46 13719.406 14593.1565
         predictif_fast(sampx)   189.093   213.6505   263.74   246.192   260.7895
       max neval cld
 79136.335   100   b
  2344.059   100  a 

注意我们通过矢量化获得了 50 倍的加速。

顺便说一句,可以大大加快速度(如果您对索引很聪明,ifelse 有更快的替代方案),但总体上从 "perform a function on each row" 切换到 "perform operations on entire vectors"让你获得最大的加速。


这并不能完全解决您的问题,因为您需要在一般树上执行这些矢量化操作,而不仅仅是在这个特定的树上。我不会为您解决通用版本,但考虑到您可以重写 godowntree 函数,以便它获取整个数据帧并对完整数据帧执行操作,而不仅仅是一个数据帧。然后,不使用 if 分支,而是保留每个实例当前所在子实例的向量。