为什么我的 dplyr 代码使用 mutate 和 zoo 创建多个变量非常慢?

Why is my dplyr code to create multiple variables using mutate and zoo incredibly slow?

我正在使用 dplyr 通过 mutate 在我的数据框中创建多个变量。同时,我使用 zoo 来计算滚动平均值。例如,我的变量设置如下:

# Names of the columns to compute rolling means for. A character vector has to
# be built with c(); a bare comma-separated list is a syntax error.
vars <- c("total_apples", "total_oranges", "total_bananas")

我的数据有超过 100 个变量和大约 40,000 行,以上只是一个例子。

我正在使用以下代码:

library(dplyr)
library(zoo)
# Within each `fruit` group, apply a right-aligned rolling mean of width 2 to
# every column named in `vars`. The function is registered under the name
# `avge_last2`, and partial = TRUE is passed to rollapplyr().
data <- data %>%
  group_by(fruit) %>%
  mutate(
    across(
      all_of(vars),
      list(avge_last2 = ~ zoo::rollapplyr(., 2, FUN = mean, partial = TRUE))
    )
  )

仅仅像上面这样计算最近 2 条记录的平均值,就需要 5 分多钟:

> end.time <- Sys.time()
> time.taken <- end.time - start.time
> time.taken
Time difference of 5.925337 mins

如果我想对更多记录进行平均,则需要更长的时间,比如 n=10,如下所示:

library(dplyr)
library(zoo)
# Same computation as the 2-row version, but with a rolling width of 10.
# The function is registered as `avge_last10` so that the generated column
# names reflect the window actually used (the original reused `avge_last2`).
data <- data %>%
  group_by(fruit) %>%
  mutate(
    across(
      all_of(vars),
      list(avge_last10 = ~ zoo::rollapplyr(., 10, FUN = mean, partial = TRUE))
    )
  )

我的代码有问题还是其他原因?

dput(head(data,20)) 提供以下内容:

structure(list(match_id = c(14581L, 14581L, 14581L, 14581L, 14581L, 
14581L, 14581L, 14581L, 14581L, 14581L, 14581L, 14581L, 14581L, 
14581L, 14581L, 14581L, 14581L, 14581L, 14581L, 14581L), match_date = structure(c(16527, 
16527, 16527, 16527, 16527, 16527, 16527, 16527, 16527, 16527, 
16527, 16527, 16527, 16527, 16527, 16527, 16527, 16527, 16527, 
16527), class = "Date"), season = c(2015, 2015, 2015, 2015, 2015, 
2015, 2015, 2015, 2015, 2015, 2015, 2015, 2015, 2015, 2015, 2015, 
2015, 2015, 2015, 2015), match_round = c(1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), home_team = c(3, 3, 3, 
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3), away_team = c(14, 
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 
14, 14, 14), venue = c(11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 
11, 11, 11, 11, 11, 11, 11, 11, 11, 11), venue_name = c("MCG", 
"MCG", "MCG", "MCG", "MCG", "MCG", "MCG", "MCG", "MCG", "MCG", 
"MCG", "MCG", "MCG", "MCG", "MCG", "MCG", "MCG", "MCG", "MCG", 
"MCG"), opponent = c(14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 
14, 14, 14, 14, 14, 14, 14, 14, 14, 14), player_id = c(11186L, 
11215L, 11285L, 11330L, 11380L, 11388L, 11407L, 11472L, 11473L, 
11490L, 11553L, 11561L, 11573L, 11582L, 11598L, 11601L, 11616L, 
11643L, 11671L, 11737L), player_first_name = c("Chris", "Chris", 
"Kade", "Troy", "Andrew", "Brett", "Cameron", "Marc", "Dale", 
"Ivan", "Bryce", "Shane", "Bachar", "Jack", "Andrejs", "Shaun", 
"Michael", "Lachie", "Trent", "Alex"), player_last_name = c("Judd", 
"Newman", "Simpson", "Chaplin", "Carrazzo", "Deledio", "Wood", 
"Murphy", "Thomas", "Maric", "Gibbs", "Edwards", "Houli", "Riewoldt", 
"Everitt", "Grigg", "Jamison", "Henderson", "Cotchin", "Rance"
), player_team = c("Carlton", "Richmond", "Carlton", "Richmond", 
"Carlton", "Richmond", "Carlton", "Carlton", "Carlton", "Richmond", 
"Carlton", "Richmond", "Richmond", "Richmond", "Carlton", "Richmond", 
"Carlton", "Carlton", "Richmond", "Richmond"), player_team_numeric = c(3, 
14, 3, 14, 3, 14, 3, 3, 3, 14, 3, 14, 14, 14, 3, 14, 3, 3, 14, 
14), guernsey_number = c(5L, 1L, 6L, 25L, 44L, 3L, 36L, 3L, 39L, 
20L, 4L, 10L, 14L, 8L, 33L, 6L, 40L, 23L, 9L, 18L), player_position = c(3, 
14, 14, 1, 17, 13, 16, 12, 20, 16, 14, 5, 10, 8, 13, 14, 6, 7, 
3, 2), disposals = c(21L, 7L, 21L, 13L, 18L, 18L, 11L, 21L, 1L, 
13L, 26L, 21L, 21L, 17L, 18L, 17L, 8L, 10L, 17L, 18L), kicks = c(16L, 
6L, 13L, 9L, 9L, 9L, 8L, 9L, 1L, 8L, 15L, 9L, 15L, 13L, 14L, 
9L, 4L, 9L, 6L, 9L), marks = c(5L, 1L, 8L, 1L, 2L, 3L, 2L, 2L, 
0L, 4L, 4L, 1L, 5L, 8L, 8L, 4L, 2L, 6L, 3L, 4L), handballs = c(5L, 
1L, 8L, 4L, 9L, 9L, 3L, 12L, 0L, 5L, 11L, 12L, 6L, 4L, 4L, 8L, 
4L, 1L, 11L, 9L), tackles = c(6L, 1L, 2L, 2L, 2L, 0L, 1L, 2L, 
0L, 4L, 4L, 3L, 1L, 0L, 2L, 2L, 1L, 2L, 1L, 0L), clearances = c(6L, 
0L, 0L, 0L, 6L, 1L, 6L, 4L, 0L, 4L, 4L, 7L, 0L, 0L, 1L, 3L, 0L, 
0L, 1L, 1L), brownlow_votes = c(0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 
0L, 0L, 1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L), effective_disposals = c(15L, 
6L, 16L, 11L, 16L, 13L, 6L, 14L, 1L, 11L, 13L, 16L, 16L, 10L, 
14L, 12L, 5L, 6L, 9L, 17L), disposal_efficiency_percentage = c(71L, 
86L, 76L, 85L, 89L, 72L, 55L, 67L, 100L, 85L, 50L, 76L, 76L, 
59L, 78L, 71L, 63L, 60L, 53L, 94L), contested_possessions = c(11L, 
3L, 5L, 7L, 9L, 6L, 7L, 9L, 1L, 9L, 9L, 15L, 1L, 7L, 3L, 4L, 
3L, 4L, 5L, 5L), uncontested_possessions = c(10L, 4L, 17L, 6L, 
10L, 12L, 4L, 12L, 0L, 4L, 17L, 7L, 18L, 9L, 14L, 11L, 5L, 7L, 
12L, 14L), time_on_ground_percentage = c(79L, 65L, 73L, 100L, 
76L, 69L, 89L, 81L, 1L, 88L, 73L, 83L, 85L, 98L, 95L, 81L, 96L, 
91L, 86L, 96L), afl_fantasy_score = c(93L, 26L, 97L, 42L, 54L, 
53L, 61L, 67L, 4L, 91L, 96L, 67L, 78L, 89L, 80L, 80L, 30L, 54L, 
54L, 58L), contested_marks = c(0L, 0L, 0L, 0L, 0L, 1L, 1L, 0L, 
0L, 2L, 1L, 0L, 1L, 3L, 0L, 0L, 0L, 1L, 0L, 0L), metres_gained = c(474L, 
231L, 269L, 165L, 128L, 181L, 151L, 227L, -7L, 160L, 466L, 332L, 
709L, 268L, 464L, 283L, 99L, 257L, 203L, 288L), turnovers = c(5L, 
3L, 4L, 2L, 3L, 2L, 2L, 4L, 0L, 1L, 6L, 2L, 5L, 8L, 5L, 2L, 2L, 
3L, 3L, 1L), effective_kicks = c(11L, 5L, 9L, 7L, 7L, 4L, 3L, 
5L, 1L, 6L, 5L, 4L, 11L, 7L, 12L, 5L, 2L, 6L, 1L, 9L), ground_ball_gets = c(8L, 
2L, 4L, 5L, 7L, 4L, 4L, 8L, 0L, 3L, 6L, 9L, 0L, 4L, 3L, 2L, 2L, 
2L, 5L, 3L), cum_rec = c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 
13, 14, 15, 16, 17, 18, 19, 20), rank_match_kicks = c(2, 34, 
10.5, 20.5, 20.5, 20.5, 28, 20.5, 43, 28, 4.5, 20.5, 4.5, 10.5, 
8, 20.5, 39.5, 20.5, 34, 20.5), rank_match_marks = c(14, 39, 
5, 39, 33, 27.5, 33, 33, 43.5, 20.5, 20.5, 39, 14, 5, 5, 20.5, 
33, 10, 27.5, 20.5)), row.names = c(NA, -20L), class = c("tbl_df", 
"tbl", "data.frame"))

更新:

使用下面答案中建议的函数考虑下面的示例:

# Toy example: three players, each with one row in each of three matches.
match_id <- rep(paste0("match_", 1:3), each = 3)
player_id <- rep(paste0("player_", 1:3), times = 3)
turnovers <- c(5, 10, 15, 1, 2, 3, 5, 7, 9)

data <- data.frame(match_id, player_id, turnovers)
    
# Rolling mean of each column in `vars`, computed with frollmean() inside the
# groups given by `byvars`.
#   dt      - table the `[` call with by/.SDcols is applied to
#   window  - passed to frollmean() as `n`
#   vars    - column names handed to .SDcols
#   byvars  - grouping column(s) handed to `by`
#   partial - when true, the output of partials() is stacked on top of a
#             subset of the full-window result
f <- function(dt, window, vars, byvars, partial=F) {
  res = dt[, lapply(.SD, frollmean, n=window), by=byvars, .SDcols=vars]
  if(partial) {
    res = rbind(
      # rows produced by partials() for window sizes 1 .. window-1
      partials(dt,window-1,vars, byvars),
      # NOTE(review): `window:.N` is in the i position here, so it appears to
      # pick rows window..nrow of the whole table before `by` is applied,
      # rather than rows window..N inside each group - presumably the source
      # of the NA rows in the output shown below; confirm against a per-group
      # `.SD[window:.N]` in j.
      res[window:.N, .SD, by=byvars]
    )
  }
  return(res)
}

# Helper used by f(): for every p in 1..w, computes the mean of each column in
# `vars` over a p-row trailing window (sum of the column shifted by 0..p-1,
# divided by p) and stacks the per-p results with rbindlist().
partials <- function(dt,w,vars,byvars) {
  rbindlist(lapply(1:w, function(p) {
    # NOTE(review): `1:p` and `p:.N` are both in the i position, so they look
    # like they index rows of the whole table rather than rows within each
    # `byvars` group - verify; a per-group selection would go in j via .SD.
    dt[1:p, lapply(.SD, function(v) Reduce(`+`, shift(v,0:(p-1)))/p),
       .SDcols = vars, by=byvars][p:.N, .SD, by=byvars]
  }))
}

# convert `data` to a data.table by reference
setDT(data)

# columns to average
vars <- "turnovers"

# sort by player, then by match, before taking rolling means
setorder(data, player_id, match_id)

# rolling-window width
n <- 3

# rolling mean of every column in `vars` for each player, with the partial
# windows requested as well
f(data, window = n, vars = vars, byvars = "player_id", partial = TRUE)

这会返回以下结果:

   player_id turnovers
1:  player_1  5.000000
2:  player_1  3.000000
3:  player_1  3.666667
4:  player_2        NA
5:  player_2        NA
6:  player_2  6.333333
7:  player_3        NA
8:  player_3        NA
9:  player_3  9.000000

我做错了什么?

我发现在 dplyr 中处理分组后的数据框确实会拖慢速度。我不确定这是否是最好的解决办法,但在完成分组操作之后,我会通过管道接上

%>% as.data.frame()

去掉分组信息,然后再做我的计算。它可以节省很多时间。如果您之前对大型数据集进行过分组,请尝试一下。

你可以试试这个:

library(data.table)

# convert `data` to a data.table in place
setDT(data)

# Add, by reference, one new column per entry of `vars` (named
# "<var>_avge_last2_") holding frollmean() with n = 2, computed separately
# for each value of `fruit`.
data[, (paste0(vars, "_avge_last2_")) := lapply(.SD, frollmean, n = 2),
     by = "fruit", .SDcols = vars]

更新

这是一个更通用的解决方案,用于处理每个分组开头、窗口尚不完整时产生的 NA(即部分窗口)

首先,定义一个函数,它接受一个 data.table(dt)、一个窗口大小(window)、一组变量(vars)、一组分组变量(byvars),以及一个可选的逻辑标志 partial

# Grouped rolling mean for a data.table.
#
# Arguments:
#   dt      - a data.table, already sorted in the order the window should run
#   window  - rolling-window width, passed to frollmean() as `n`
#   vars    - character vector of columns to average
#   byvars  - character vector of grouping columns
#   partial - if TRUE, the leading NA rows of each group are replaced by means
#             over the shorter windows 1 .. window-1 (computed by partials())
#
# Returns a data.table with the `byvars` columns followed by one column per
# entry of `vars`. With partial = TRUE the partial-window rows come first
# (stacked by window size), followed by the full-window rows.
f <- function(dt, window, vars, byvars, partial = FALSE) {
  res <- dt[, lapply(.SD, frollmean, n = window), by = byvars, .SDcols = vars]
  # A window of 1 has no partial windows; calling partials() with w = 0 would
  # iterate over 1:0 and produce garbage rows.
  if (partial && window > 1) {
    res <- rbind(
      partials(dt, window - 1, vars, byvars),
      # Keep only rows window..N of each group. A logical mask is used instead
      # of `window:.N`, because that sequence counts backwards when a group has
      # fewer than `window` rows and would add spurious NA rows.
      res[, .SD[seq_len(.N) >= window], by = byvars]
    )
  }
  res
}

再添加一个可选的辅助函数 partials()

# Partial-window means for the first rows of each group.
#
# For every window size p in 1..w, returns one row per group holding the mean
# of the first p rows of each column in `vars`. This is the value a trailing
# rolling mean of width p takes at row p.
#
# Arguments:
#   dt     - a data.table, already sorted in the order the window should run
#   w      - largest partial window size (the full window width minus one)
#   vars   - character vector of columns to average
#   byvars - character vector of grouping columns
#
# Returns the per-p results stacked with rbindlist(), ordered by p and then by
# group. Groups with fewer than p rows contribute nothing for that p.
partials <- function(dt, w, vars, byvars) {
  # seq_len() yields an empty sequence for w = 0, where 1:w would give c(1, 0).
  per_window <- lapply(seq_len(w), function(p) {
    dt[
      ,
      # Returning NULL drops groups that are too short; indexing past the end
      # of a group would otherwise pad it with NA rows.
      if (.N >= p) lapply(.SD, function(v) sum(v[seq_len(p)]) / p),
      by = byvars,
      .SDcols = vars
    ]
  })
  # fill = TRUE tolerates an item in which every group was too short.
  rbindlist(per_window, fill = TRUE)
}

应用函数

# convert `data` to a data.table by reference
setDT(data)

# columns to average
vars <- c("turnovers", "effective_kicks")

# sort by match, then by player, before taking rolling means
setorder(data, match_id, player_id)

# rolling-window width
n <- 3

# rolling mean of every column in `vars` for each player, with the partial
# windows requested as well
f(data, window = n, vars = vars, byvars = "player_id", partial = TRUE)

有几个问题:

  • 问题中的代码无法在所提供的数据上运行,而是会报错:数据中没有 fruit 列,vars 中列出的列也不存在。为了让它能够运行,我们按 match_id 分组,并将 vars 定义为一些实际存在的列。

  • 最好不要覆盖数据,而是为输出使用不同的名称,以便于调试。

  • 使用 across 导致 rollapplyr 分别应用于每个列,考虑到 rollapply 可以一次处理多个列,这是低效的。

使用所提供数据中实际存在的列,并假设我们想对 vars 中列出的列使用 rollapplyr,可以试试下面的代码:它对每个分组只运行一次 rollapplyr,速度似乎略快一些。

另外,用 fill = NA 代替 partial = TRUE 可以让 rollapplyr 使用更快的算法;不过这样一来,每组的第一行将是 NA(这正是 fill = NA 的含义),而且如果要求平均的列中本身已含有 NA,就不会使用该快速算法。

library(dplyr)
library(zoo)

# columns to average; both exist in the sample data
vars <- c("home_team", "away_team")

# For each match, call rollapplyr() once on all columns in `vars` together and
# bind the resulting averages (columns prefixed with "avg.") to that group's
# rows. group_modify() is what makes this per-group: piping a grouped tibble
# straight into data.frame() hands rollapplyr() the whole table and the
# grouping is ignored.
# NOTE(review): behaviour for a group with a single row has not been checked.
data_out <- data %>%
  group_by(match_id) %>%
  group_modify(
    ~ data.frame(.x, avg = rollapplyr(.x[vars], 2, mean, partial = TRUE))
  ) %>%
  ungroup()