data.table |组内更快的逐行递归更新
data.table | faster row-wise recursive update within group
我必须进行以下逐行递归操作才能获得z
:
myfun = function (xb, a, b) {
z = NULL
for (t in 1:length(xb)) {
if (t >= 2) { a[t] = b[t-1] + xb[t] }
z[t] = rnorm(1, mean = a[t])
b[t] = a[t] + z[t]
}
return(z)
}
set.seed(1)
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]
smpl[, z := myfun(xb, a, b), by = id]
我想得到这样的结果:
id time a b xb z
1: 1 1 1 1 1 0.3735462
2: 1 2 1 1 2 2.7470924
3: 1 3 1 1 3 8.4941848
4: 1 4 1 1 4 20.9883695
5: 1 5 1 1 5 46.9767390
---
496: 100 1 1 1 100 0.3735462
497: 100 2 1 1 200 200.7470924
498: 100 3 1 1 300 701.4941848
499: 100 4 1 1 400 1802.9883695
500: 100 5 1 1 500 4105.9767390
这确实有效,但需要时间:
system.time(smpl[, z := myfun(xb, a, b), by = id])
user system elapsed
33.646 0.994 34.473
考虑到我的实际数据量(超过 200 万次观察),我需要让它更快。我想 do.call(myfun, .SD), .SDcols = c('xb', 'a', 'b')
和 by = .(id, time)
会快得多,避免了 myfun
中的 for 循环。但是,当我 运行 在 data.table
中进行逐行操作时,我不确定如何更新 b
及其滞后(可能使用 shift
)。有什么建议吗?
好问题!
从一个新的 R 会话开始,显示 500 万行的演示数据,这是问题的函数和我笔记本电脑上的时间。内联一些评论。
require(data.table) # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]
myfun = function (xb, a, b) {
z = NULL
# initializes a new length-0 variable
for (t in 1:length(xb)) {
if (t >= 2) { a[t] = b[t-1] + xb[t] }
# if() on every iteration. t==1 could be done before loop
z[t] = rnorm(1, mean = a[t])
# z vector is grown by 1 item, each time
b[t] = a[t] + z[t]
# assigns to all of b vector when only really b[t-1] is
# needed on the next iteration
}
return(z)
}
set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
user system elapsed
19.216 0.004 19.212
smpl
id time a b xb z
1: 1 1 1 1 1 3.735462e-01
2: 1 2 1 1 2 3.557190e+00
3: 1 3 1 1 3 9.095107e+00
4: 1 4 1 1 4 2.462112e+01
5: 1 5 1 1 5 5.297647e+01
---
4999996: 1000000 1 1 1 1000000 1.618913e+00
4999997: 1000000 2 1 1 2000000 2.000000e+06
4999998: 1000000 3 1 1 3000000 7.000003e+06
4999999: 1000000 4 1 1 4000000 1.800001e+07
5000000: 1000000 5 1 1 5000000 4.100001e+07
所以19.2s是打的时间。在所有这些计时中,我在本地 运行 命令 3 次以确保它是一个稳定的计时。时间差异在此任务中微不足道,因此我将只报告一个时间以使答案更快阅读。
处理上面 myfun()
中的注释:
myfun2 = function (xb, a, b) {
z = numeric(length(xb))
# allocate size up front rather than growing
z[1] = rnorm(1, mean=a[1])
prevb = a[1]+z[1]
t = 2L
while(t<=length(xb)) {
at = prevb + xb[t]
z[t] = rnorm(1, mean=at)
prevb = at + z[t]
t = t+1L
}
return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
user system elapsed
13.212 0.036 13.245
smpl[,identical(z,z2)]
[1] TRUE
非常好(从 19.2 秒降至 13.2 秒),但它仍然是 R 级的 for
循环。乍一看,它不能被矢量化,因为 rnorm()
调用依赖于先前的值。事实上,它可能可以通过使用 m+sd*rnorm(mean=0,sd=1) == rnorm(mean=m, sd=sd)
的 属性 并调用矢量化 rnorm(n=5e6)
一次而不是 5e6 次来对其进行矢量化。但可能会有 cumsum()
参与处理这些团体。所以我们不要去那里,因为这可能会使代码更难阅读,并且会特定于这个精确的问题。
那我们试试Rcpp吧,和你写的风格很相似,适用范围更广:
require(Rcpp) # v0.12.8
cppFunction(
'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
NumericVector z = NumericVector(xb.length());
z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
double prevb = a[0]+z[0];
int t = 1;
while (t<xb.length()) {
double at = prevb + xb[t];
z[t] = R::rnorm(at, 1);
prevb = at + z[t];
t++;
}
return z;
}')
set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
user system elapsed
1.800 0.020 1.819
smpl[,identical(z,z3)]
[1] TRUE
好多了:19.2s 下降到 1.8s。但是每次调用该函数都会调用第一行 (NumericVector()
),它会分配一个与组中的行数一样长的新向量。然后填写并返回,复制到该组正确位置的最后一列(:=
),仅供发布。所有这 100 万个小型临时向量(每组一个)的分配和管理都有些复杂。
我们为什么不一次性完成整个专栏?您已经用 for 循环样式编写了它,这没有任何问题。让我们调整 C 函数以接受 id
列,并在它到达新组时添加 if
。
cppFunction(
'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {
// ** id must be pre-grouped, such as via setkey(DT,id) **
NumericVector z = NumericVector(id.length());
int previd = id[0]-1; // initialize to anything different than id[0]
for (int i=0; i<id.length(); i++) {
double prevb;
if (id[i]!=previd) {
// first row of new group
z[i] = R::rnorm(a[i], 1);
prevb = a[i]+z[i];
previd = id[i];
} else {
// 2nd row of group onwards
double at = prevb + xb[i];
z[i] = R::rnorm(at, 1);
prevb = at + z[i];
}
}
return z;
}')
system.time(setkey(smpl,id)) # ensure grouped by id
user system elapsed
0.028 0.004 0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
user system elapsed
0.232 0.004 0.237
smpl[,identical(z,z4)]
[1] TRUE
更好:19.2s 下降到 0.27s。
我必须进行以下逐行递归操作才能获得z
:
myfun = function (xb, a, b) {
z = NULL
for (t in 1:length(xb)) {
if (t >= 2) { a[t] = b[t-1] + xb[t] }
z[t] = rnorm(1, mean = a[t])
b[t] = a[t] + z[t]
}
return(z)
}
set.seed(1)
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]
smpl[, z := myfun(xb, a, b), by = id]
我想得到这样的结果:
id time a b xb z
1: 1 1 1 1 1 0.3735462
2: 1 2 1 1 2 2.7470924
3: 1 3 1 1 3 8.4941848
4: 1 4 1 1 4 20.9883695
5: 1 5 1 1 5 46.9767390
---
496: 100 1 1 1 100 0.3735462
497: 100 2 1 1 200 200.7470924
498: 100 3 1 1 300 701.4941848
499: 100 4 1 1 400 1802.9883695
500: 100 5 1 1 500 4105.9767390
这确实有效,但需要时间:
system.time(smpl[, z := myfun(xb, a, b), by = id])
user system elapsed
33.646 0.994 34.473
考虑到我的实际数据量(超过 200 万次观察),我需要让它更快。我想 do.call(myfun, .SD), .SDcols = c('xb', 'a', 'b')
和 by = .(id, time)
会快得多,避免了 myfun
中的 for 循环。但是,当我 运行 在 data.table
中进行逐行操作时,我不确定如何更新 b
及其滞后(可能使用 shift
)。有什么建议吗?
好问题!
从一个新的 R 会话开始,显示 500 万行的演示数据,这是问题的函数和我笔记本电脑上的时间。内联一些评论。
require(data.table) # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]
myfun = function (xb, a, b) {
z = NULL
# initializes a new length-0 variable
for (t in 1:length(xb)) {
if (t >= 2) { a[t] = b[t-1] + xb[t] }
# if() on every iteration. t==1 could be done before loop
z[t] = rnorm(1, mean = a[t])
# z vector is grown by 1 item, each time
b[t] = a[t] + z[t]
# assigns to all of b vector when only really b[t-1] is
# needed on the next iteration
}
return(z)
}
set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
user system elapsed
19.216 0.004 19.212
smpl
id time a b xb z
1: 1 1 1 1 1 3.735462e-01
2: 1 2 1 1 2 3.557190e+00
3: 1 3 1 1 3 9.095107e+00
4: 1 4 1 1 4 2.462112e+01
5: 1 5 1 1 5 5.297647e+01
---
4999996: 1000000 1 1 1 1000000 1.618913e+00
4999997: 1000000 2 1 1 2000000 2.000000e+06
4999998: 1000000 3 1 1 3000000 7.000003e+06
4999999: 1000000 4 1 1 4000000 1.800001e+07
5000000: 1000000 5 1 1 5000000 4.100001e+07
所以19.2s是打的时间。在所有这些计时中,我在本地 运行 命令 3 次以确保它是一个稳定的计时。时间差异在此任务中微不足道,因此我将只报告一个时间以使答案更快阅读。
处理上面 myfun()
中的注释:
myfun2 = function (xb, a, b) {
z = numeric(length(xb))
# allocate size up front rather than growing
z[1] = rnorm(1, mean=a[1])
prevb = a[1]+z[1]
t = 2L
while(t<=length(xb)) {
at = prevb + xb[t]
z[t] = rnorm(1, mean=at)
prevb = at + z[t]
t = t+1L
}
return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
user system elapsed
13.212 0.036 13.245
smpl[,identical(z,z2)]
[1] TRUE
非常好(从 19.2 秒降至 13.2 秒),但它仍然是 R 级的 for
循环。乍一看,它不能被矢量化,因为 rnorm()
调用依赖于先前的值。事实上,它可能可以通过使用 m+sd*rnorm(mean=0,sd=1) == rnorm(mean=m, sd=sd)
的 属性 并调用矢量化 rnorm(n=5e6)
一次而不是 5e6 次来对其进行矢量化。但可能会有 cumsum()
参与处理这些团体。所以我们不要去那里,因为这可能会使代码更难阅读,并且会特定于这个精确的问题。
那我们试试Rcpp吧,和你写的风格很相似,适用范围更广:
require(Rcpp) # v0.12.8
cppFunction(
'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
NumericVector z = NumericVector(xb.length());
z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
double prevb = a[0]+z[0];
int t = 1;
while (t<xb.length()) {
double at = prevb + xb[t];
z[t] = R::rnorm(at, 1);
prevb = at + z[t];
t++;
}
return z;
}')
set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
user system elapsed
1.800 0.020 1.819
smpl[,identical(z,z3)]
[1] TRUE
好多了:19.2s 下降到 1.8s。但是每次调用该函数都会调用第一行 (NumericVector()
),它会分配一个与组中的行数一样长的新向量。然后填写并返回,复制到该组正确位置的最后一列(:=
),仅供发布。所有这 100 万个小型临时向量(每组一个)的分配和管理都有些复杂。
我们为什么不一次性完成整个专栏?您已经用 for 循环样式编写了它,这没有任何问题。让我们调整 C 函数以接受 id
列,并在它到达新组时添加 if
。
cppFunction(
'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {
// ** id must be pre-grouped, such as via setkey(DT,id) **
NumericVector z = NumericVector(id.length());
int previd = id[0]-1; // initialize to anything different than id[0]
for (int i=0; i<id.length(); i++) {
double prevb;
if (id[i]!=previd) {
// first row of new group
z[i] = R::rnorm(a[i], 1);
prevb = a[i]+z[i];
previd = id[i];
} else {
// 2nd row of group onwards
double at = prevb + xb[i];
z[i] = R::rnorm(at, 1);
prevb = at + z[i];
}
}
return z;
}')
system.time(setkey(smpl,id)) # ensure grouped by id
user system elapsed
0.028 0.004 0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
user system elapsed
0.232 0.004 0.237
smpl[,identical(z,z4)]
[1] TRUE
更好:19.2s 下降到 0.27s。