在列中定义的 window 内求和

Sum within a window that is defined in column

我想为组中的每个 data.table 行实现 sum(x) N 下一行,其中 N 是来自 window 列的值。

生成示例数据的代码:

set.seed(100)
ids <- 1:100
x <- floor(runif(100, 1, 100))
groups <- floor(runif(100, 1, 10)) * 10
window <- floor(runif(100, 1, 5))

library('data.table')
data <- data.table(ids, x, groups, window)
setkey(data, groups, ids)

顶行:

 ids  x groups window
 1:   3 55     10      4
 2:   9 55     10      1
 3:  13 28     10      1
 4:  16 67     10      3
 5:  26 17     10      3
 6:  30 28     10      2
 7:  36 89     10      2
 8:  38 63     10      3
 9:  42 86     10      3
10:  48 88     10      1
11:  49 21     10      1
12:  59 60     10      3
13:  65 45     10      4
14:  67 46     10      2
15:  88 25     10      4
16:  19 36     20      2

因此,对于第一行,结果值将根据当前和接下来的 4 行的总和进行计算:res = 55 + 55 + 28 + 67 + 17 = 222

对于组结束的第 15 行,我只需要当前行的值:res = 25 + 0(无行)= 25。

这是此逻辑的伪代码:

res <- data[, .(ids, groups, x, window , 
            result = sum(.SD[.CurrentRow:(.CurrentRow + Window)], na.rm = T)), 
            by = groups, .SDcols = c("x")]

我希望这可以通过 data.table 实现。我想为此避免 for 循环实现。

首先我们加载 base 包并将我们的 data.table 转换为 data.frame

set.seed(100)
ids <- 1:100
x <- floor(runif(100, 1, 100))
groups <- floor(runif(100, 1, 10)) * 10
window <- floor(runif(100, 1, 5))

library('data.table')
data <- data.table(ids, x, groups, window)
setkey(data, groups, ids)

dd <- as.data.frame(data)

并且基本上将行绑定到一个更大的数据框中,我们可以用它来使用您最喜欢的聚合方法进行汇总

tmp <- tapply(seq(nrow(dd)), dd$groups, function(ii) {
  idx <- Map(`:`, ii, ii + dd$window[ii])
  out <- dd[unlist(idx), ]
  out$idx <- rep(dd$ids[ii], lengths(idx))
  out[out$groups %in% dd$groups[ii], ]
})
tmp <- do.call('rbind', tmp)

res <- aggregate(x ~ idx + groups, tmp, sum)

#    idx groups   x
# 1    3     10 222
# 2    9     10  83
# 3   13     10  95
# 4   16     10 201
# 5   26     10 197
# 6   30     10 180
# 7   36     10 238
# 8   38     10 258
# 9   42     10 255
# 10  48     10 109
# 11  49     10  81
# 12  59     10 176
# 13  65     10 116
# 14  67     10  71
# 15  88     10  25
# 16  19     20 173

identical(table(dd$groups), table(res$group))
# [1] TRUE

我不确定是否可以在不遍历所有行的情况下执行此操作,所以这是一个这样的解决方案:

data[, end := pmin(.I + window, .I[.N]), by = groups][
     , res := sum(data$x[.I:end]), by = 1:nrow(data)][1:16]
#    ids  x groups window end res
# 1:   3 55     10      4   5 222
# 2:   9 55     10      1   3  83
# 3:  13 28     10      1   4  95
# 4:  16 67     10      3   7 201
# 5:  26 17     10      3   8 197
# 6:  30 28     10      2   8 180
# 7:  36 89     10      2   9 238
# 8:  38 63     10      3  11 258
# 9:  42 86     10      3  12 255
#10:  48 88     10      1  11 109
#11:  49 21     10      1  12  81
#12:  59 60     10      3  15 176
#13:  65 45     10      4  15 116
#14:  67 46     10      2  15  71
#15:  88 25     10      4  15  25
#16:  19 36     20      2  18 173

正如 alexis_laz 指出的那样,您可以通过计算 cumsum 一次然后减去多余的部分来做得更好,从而避免显式迭代行:

data[, res := { cs <- cumsum(x); 
                cs[pmin(1:.N + window, .N)] - shift(cs, fill = 0)}
     , by = groups]

我将尝试解释这里发生的事情:

  • res := {...} 向我们的 data.table 添加一列,括号内为 R 计算;
  • cs = cumsum(x) 计算组内所有行的 运行 总和;
  • cs[pmin(1:.N + window, .N)] 在 window 末尾或组的最后一行取 运行 总和的值;
  • shift(cs, fill = 0) 从前一行得到 运行 总和;
  • 两者之差给出了 window 内项目的总和。

由于这个问题有几个有效的答案,我认为值得在这里提供基准测试:

library(microbenchmark)
m <- microbenchmark(
               "tapply(rawr)" = tapplyWay(dd),
               "data.table(eddi)" = data[, end := pmin(.I + window, .I[.N]), by = groups][
                   , res := sum(data$x[.I:end]), by = 1:nrow(data)],
               "data.table(alexis_laz)" = data[, res := {cs = cumsum(x); cs[pmin(1:.N + window, .N)] - shift(cs, fill = 0)}
                                               , by = groups],
               times = 10)
print(m)
boxplot(m)

10^5 行示例的结果:

Unit: milliseconds
            expr     min     lq      mean    median      uq    max        neval
       tapply(rawr) 2575.12 2761.365 2898.63 2905.77  3041.08  3127.86    10
   data.table(eddi) 1418.92 1570.230 1758.70 1656.14  1977.59  2358.85    10
     dt(alexis_laz) 6.82    7.73     8.78    8.09     10.37    12.37119    10