Generating interactive partial dependence plots in R with a loop is extremely slow

I'm trying to generate interactive partial dependence plots by looping over the columns of my dataset.

A reproducible example:

library(pdp)
library(xgboost)
library(Matrix)
library(ggplot2)
library(plotly)

data(mtcars)
target <- mtcars$mpg
mtcars$mpg <- NULL

mtcars.sparse <- sparse.model.matrix(target~., mtcars)

fit <- xgboost(data=mtcars.sparse, label=target, nrounds=100)

for (i in seq_along(names(mtcars))){
  p1 <- pdp::partial(fit,
                     pred.var = names(mtcars)[i],
                     pred.grid = data.frame(unique(mtcars[names(mtcars)[i]])),
                     train = mtcars.sparse,
                     type = "regression",
                     cats = c("cyl", "vs", "am", "gear", "carb"),
                     plot = FALSE)
  p2 <- ggplot(aes_string(x = names(mtcars)[i] , y = "yhat"), data = p1) +
    geom_line(color = '#E51837', size = .6) +
    labs(title = paste("Partial Dependence plot of", names(mtcars)[i] , sep = " ")) +
    theme(text = element_text(color = "#444444", family = 'Helvetica Neue'),
          plot.title = element_text(size = 13, color = '#333333'))

  print(ggplotly(p2, tooltip = c("x", "y")))

}

On my real dataset (about 22k rows, 30 columns) the plotting loop takes about 2 hours. Any ideas on how to speed it up?

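Because of the way R handles data structures, for() loops can be very slow if you're not careful: for example, growing a vector or list inside the loop can force R to copy it on every iteration. If you want to learn more about the technical reasons behind this, see Hadley Wickham's Advanced R.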

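In practice, there are two main ways to speed up what you're doing: optimizing the for() loop, or using the apply() family of functions. Both can work well, but the apply() approach tends to be faster than even a well-optimized for() loop, so I'll stick with that solution.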
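The difference between a naive loop, an optimized loop and apply() is easiest to see side by side. The sketch below uses a hypothetical stand-in, summarise_col(), in place of the plotting function so it runs with base R alone, and fills the same result three ways: growing a list inside the loop, preallocating the list, and lapply(). All three produce identical results; with only 11 columns the timing differences are negligible, so this illustrates the structure rather than serving as a benchmark.

```r
# Hypothetical stand-in for the per-column work (plotFunction in the answer);
# it just averages one column so the example needs only base R.
summarise_col <- function(col_name, data = datasets::mtcars) {
  mean(data[[col_name]])
}

col_names <- names(datasets::mtcars)

# 1. Growing a list inside the loop: R may have to copy it as it grows.
grown <- list()
for (nm in col_names) {
  grown[[nm]] <- summarise_col(nm)
}

# 2. Optimized loop: allocate the full result list once, then fill it.
prealloc <- vector("list", length(col_names))
names(prealloc) <- col_names
for (i in seq_along(col_names)) {
  prealloc[[i]] <- summarise_col(col_names[i])
}

# 3. lapply(): the output list is allocated for you.
applied <- lapply(col_names, summarise_col)
names(applied) <- col_names
```

If you stay with a for() loop, the main thing is to allocate the output once with vector("list", n) instead of growing it.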

The apply() approach:

# Compute the partial dependence of the model on variable x and return a styled ggplot
plotFunction <- 
  function(x) {
    p1 <- pdp::partial(fit,
                       pred.var = x,
                       pred.grid = data.frame(unique(mtcars[x])),
                       train = mtcars.sparse,
                       type = "regression",
                       cats = c("cyl", "vs", "am", "gear", "carb"),
                       plot = FALSE)
    p2 <- ggplot(aes_string(x = x , y = "yhat"), data = p1) +
      geom_line(color = '#E51837', size = .6) +
      labs(title = paste("Partial Dependence plot of", x , sep = " ")) +
      theme(text = element_text(color = "#444444", family = 'Helvetica Neue'),
            plot.title = element_text(size = 13, color = '#333333'))
    return(p2)
  }


# Column names to loop over (one plot per predictor)
varNames <- as.list(names(mtcars))
plot.list <- lapply(varNames, plotFunction)

system.time(lapply(varNames, plotFunction))
   user  system elapsed 
  0.471   0.004   0.488 

Running the same benchmark on your for() loop gives:

   user  system elapsed 
  3.945   0.616   3.519 

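As you can see, just moving the loop body into a function with a few small changes makes it roughly 10 times faster. Note, however, that plotFunction() only builds the ggplot objects: unlike your loop, it never calls ggplotly() or print(), so part of that gap comes from skipping the interactive rendering rather than from lapply() itself.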
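To isolate the effect of lapply() from everything else in the loop body, time one function driven both ways. In this sketch, work() is a hypothetical stand-in for plotFunction() that repeats a cheap computation so the timing is measurable; substitute your own function and column names to see how much of the difference on your data comes from the looping construct.

```r
# Hypothetical stand-in for plotFunction(): a fixed amount of work per column.
work <- function(col_name, data = datasets::mtcars) {
  x <- data[[col_name]]
  # Repeat a cheap computation so there is something to time.
  for (k in 1:2000) s <- sum(sort(x))
  s
}

col_names <- names(datasets::mtcars)

# Same function, preallocated output, driven by a for() loop...
t_loop <- system.time({
  out_loop <- vector("list", length(col_names))
  for (i in seq_along(col_names)) out_loop[[i]] <- work(col_names[i])
})

# ...and driven by lapply().
t_lapply <- system.time(out_lapply <- lapply(col_names, work))

# user, system and elapsed seconds for each approach
print(rbind(for_loop = t_loop, lapply = t_lapply)[, 1:3])
```

Because both versions call exactly the same function, any remaining difference in the printed timings is due to the loop construct alone.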

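If you want even more speed, you can tweak the function further, but perhaps the most powerful aspect of the apply() approach is that it lends itself well to parallelization, which you can do with packages like pbmcapply.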
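pbmclapply() adds a progress bar to the mc*apply() functions from R's base parallel package, which rely on fork() and therefore only run in parallel on Linux and macOS. As a minimal sketch of the same idea using only the parallel package, the code below picks mclapply() where forking is available and a socket cluster with parLapply() on Windows; slow_square() is a hypothetical stand-in for plotFunction().

```r
library(parallel)

# Hypothetical stand-in for plotFunction(): sleeps briefly to mimic slow work.
slow_square <- function(x) {
  Sys.sleep(0.05)
  x^2
}

# Leave one core free; detectCores() can return NA, so fall back to 1.
n_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)

if (.Platform$OS.type == "windows") {
  # No fork() on Windows: start a socket cluster instead.
  cl <- makeCluster(n_cores)
  res <- parLapply(cl, 1:8, slow_square)
  stopCluster(cl)
} else {
  # Forked workers see the parent session's objects (e.g. fit, mtcars.sparse).
  res <- mclapply(1:8, slow_square, mc.cores = n_cores)
}
```

With a socket cluster the workers start as empty R sessions, so for the real plotFunction() you would first load the packages and copy fit, mtcars and mtcars.sparse to them, for example with clusterEvalQ() and clusterExport().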

Using pbmcapply makes it even faster:

library(pdp)
library(xgboost)
library(Matrix)
library(ggplot2)
library(plotly)
library(pbmcapply)

# Determines the number of cores you want to use for parallel processing.
# I like to leave two of mine free, but leaving one is usually enough.
# detectCores() lives in the base 'parallel' package, so call it explicitly.
nCores <- parallel::detectCores() - 1

data(mtcars)
target <- mtcars$mpg
mtcars$mpg <- NULL

mtcars.sparse <- sparse.model.matrix(target~., mtcars)

fit <- xgboost(data=mtcars.sparse, label=target, nrounds=100)

varNames <- 
  names(mtcars) %>%
  as.list

# Compute the partial dependence of the model on variable x and return a styled ggplot
plotFunction <- 
  function(x) {
    p1 <- pdp::partial(fit,
                       pred.var = x,
                       pred.grid = data.frame(unique(mtcars[x])),
                       train = mtcars.sparse,
                       type = "regression",
                       cats = c("cyl", "vs", "am", "gear", "carb"),
                       plot = FALSE)
    p2 <- ggplot(aes_string(x = x , y = "yhat"), data = p1) +
      geom_line(color = '#E51837', size = .6) +
      labs(title = paste("Partial Dependence plot of", x , sep = " ")) +
      theme(text = element_text(color = "#444444", family = 'Helvetica Neue'),
            plot.title = element_text(size = 13, color = '#333333'))
    return(p2)
  }


plot.list <- pbmclapply(varNames, plotFunction, mc.cores = nCores)

Let's see how it performs:

   user  system elapsed 
  0.842   0.458   0.320 

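That is a small improvement over lapply(), but the gain should grow with larger datasets, where each partial() call takes longer. Hope this helps!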
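In both versions plotFunction() returns ordinary ggplot objects, so plot.list does not yet hold the interactive plots the original loop printed. One way to get them, sketched below, is to map ggplotly() over the list with the same tooltip setting. The example builds a hypothetical stand-in list, demo_plots (simple scatter plots of mpg from datasets::mtcars, not partial dependence curves), so that it runs on its own; swap in your plot.list.

```r
library(ggplot2)
library(plotly)

# Hypothetical stand-in for plot.list: one scatter plot per predictor column.
make_plot <- function(v) {
  ggplot(datasets::mtcars, aes(x = .data[[v]], y = mpg)) +
    geom_point()
}
demo_plots <- lapply(setdiff(names(datasets::mtcars), "mpg"), make_plot)

# Convert every ggplot into an interactive plotly widget in one pass,
# forwarding the tooltip argument to ggplotly().
widget_list <- lapply(demo_plots, ggplotly, tooltip = c("x", "y"))

# To show them one after another in the viewer, as the original loop did:
# for (w in widget_list) print(w)
```

Converting after the fact also lets you build the ggplot objects in parallel and keep the widget conversion in the main session.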