加快矢量查找 data.table

Speed up vector-lookup for data.table

data.table 中的每一行,我需要从向量中找到最接近的较小数字。以下最小工作示例可以完成这项工作,但速度太慢,尤其是对于较长的 pre.numbers 向量(实际数据中约有 100 万个元素)。

library(data.table)
set.seed(2)
pre.numbers <- sort(floor(runif(50000, 1, 1000000)))
the.table <- data.table(cbind(rowid=1:10000, current.number=floor(runif(1000, 1, 100000)), closest.lower.number=NA_integer_))
setkey(the.table, rowid)
the.table[, closest.lower.number:=max(pre.numbers[pre.numbers<current.number]), by=rowid]

一定有更聪明的方法来做到这一点。 vector-numbers 和 data.table.

中的数字之间没有关系

这是一个向量化的解决方案:

algo1 = function()
{
    vec     = the.table$current.number
    indices = findInterval(vec-0.1, pre.numbers)
    res     = ifelse(indices==0, 0, vec[indices])

    the.table$closest.lower.number = res
}

algo2 = function()
{
    setkey(the.table, rowid)
    the.table[, closest.lower.number:=max(pre.numbers[pre.numbers<current.number]), by=rowid]
}

在我的机器上:

t1 = system.time(algo1())
#> t1
#user  system elapsed 
#0.0      0.0     0.0 

t2 = system.time(algo2())
#> t2
#user  system elapsed 
#9.73    0.00    9.73 

这个怎么样?使用 data.table 的滚动连接:

DT = data.table(pre = pre.numbers, 
       current.number = pre.numbers+0.5, key="current.number")
setkey(the.table, current.number)
ans = DT[the.table, roll=Inf, rollends=FALSE]

由于您处理的是整数,我刚刚添加了 0.5(0 到 1 之间的任何数字都可以)以从 pre.numbers.

创建 DT

最后一步执行 LOCF 滚动连接(上次观察结转)。对于 current.number(键列)的每个值,在 DTcurrent.number(键列)中查找匹配的行。如果没有匹配项,则前滚最后一次观察。如果匹配发生在 start/end,则结果为 NA (rollends = FALSE).

为了更好地说明发生了什么,请考虑以下情况:

# pre.numbers:
# c(10, 11)

# the.table:
# current.numbers
#               9
#              10
#              11

我们首先将 pre.numbers 转换为 DT,这将导致列

# DT:
# pre  current.numbers (key col)
#  10             10.5
#  11             11.5

对于 the.table 中的每个值:

#  9 -> falls before 10.5, LOCF and rollends = FALSE => result is NA
# 10 -> falls before 10.5 => same as above
# 11 -> falls between 10.5 and 11.5, LOCF matches with previous row = 10.5 
#       corresponding pre = 10.

HTH


这是我用来生成数据的代码:

require(data.table)
set.seed(1L)
pre.numbers = sort(floor(runif(50000, 1, 1000000)))
the.table = data.table(rowid=1:10000, current.number=floor(runif(1000, 1, 100000)))