numba 中两个列表的交集

Intersection of two lists in numba

我想知道在 numba 函数中计算两个列表的交集的最快方法。只是为了澄清:两个列表的交集示例:

Input : 
lst1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
lst2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
Output :
[9, 10, 4, 5]

问题是,这需要在 numba 函数中计算,因此例如套不能用。你有好主意吗? 我当前的代码非常基础。我认为还有改进的余地。

@nb.njit
def intersection:
   result = []
   for element1 in lst1:
      for element2 in lst2:
         if element1 == element2:
            result.append(element1)
   ....

您可以为此使用集合操作:

def intersection(lst1, lst2): 
    return list(set(lst1) & set(lst2))

然后简单地调用函数intersection(lst1,lst2)。这将是最简单的方法。

由于 numba 以机器码形式编译和运行您的代码,因此您可能最适合这种简单的操作。 我 运行 下面的一些基准测试

@nb.njit
def loop_intersection(lst1, lst2):
    result = []
    for element1 in lst1:
        for element2 in lst2:
            if element1 == element2:
                result.append(element1)
    return result

@nb.njit
def set_intersect(lst1, lst2):
    return set(lst1).intersection(set(lst2))

结果

loop_intersection
40.4 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

set_intersect
42 µs ± 6.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

我试了一下这个以尝试学习一些东西,意识到答案已经给出了。当我 运行 接受的答案时,我得到 return 值 [9, 10, 5, 4, 9]。我不清楚重复的 9 是否可以接受。 假设没问题,我 运行 尝试使用列表理解来查看它是否有所不同。我的结果:

from numba import jit

def createLists():
    l1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
    l2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]

@jit
def listComp():
    l1, l2 = createLists()
    return [i for i in l1 for j in l2 if i == j]

%timeit listComp() 5.84 微秒 +/- 10.5 纳秒

或者,如果您可以使用 Numpy,此代码会更快并删除重复的“9”,并且使用 Numba 签名会更快。

import numpy as np
from numba import jit, int64

@jit(int64[:](int64[:], int64[:]))
def JitListComp(l1, l2):
    l3 = np.array([i for i in l1 for j in l2 if i == j])
    return np.unique(l3) # and i not in crossSec]

@jit
def CreateList():
    l1 = np.array([15, 9, 10, 56, 23, 78, 5, 4, 9])
    l2 = np.array([9, 4, 5, 36, 47, 26, 10, 45, 87])
    return JitListComp(l1, l2)

CreateList()
Out[39]: array([ 4,  5,  9, 10])

%timeit CreateList()
1.71 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)