numba 中两个列表的交集
Intersection of two lists in numba
我想知道在 numba 函数中计算两个列表的交集的最快方法。只是为了澄清:两个列表的交集示例:
Input :
lst1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
lst2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
Output :
[9, 10, 4, 5]
问题是,这需要在 numba 函数中计算,因此例如套不能用。你有好主意吗?
我当前的代码非常基础。我认为还有改进的余地。
@nb.njit
def intersection:
result = []
for element1 in lst1:
for element2 in lst2:
if element1 == element2:
result.append(element1)
....
您可以为此使用集合操作:
def intersection(lst1, lst2):
return list(set(lst1) & set(lst2))
然后简单地调用函数intersection(lst1,lst2)
。这将是最简单的方法。
由于 numba 以机器码形式编译和运行您的代码,因此您可能最适合这种简单的操作。
我 运行 下面的一些基准测试
@nb.njit
def loop_intersection(lst1, lst2):
result = []
for element1 in lst1:
for element2 in lst2:
if element1 == element2:
result.append(element1)
return result
@nb.njit
def set_intersect(lst1, lst2):
return set(lst1).intersection(set(lst2))
结果
loop_intersection
40.4 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
set_intersect
42 µs ± 6.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
我试了一下这个以尝试学习一些东西,意识到答案已经给出了。当我 运行 接受的答案时,我得到 return 值 [9, 10, 5, 4, 9]。我不清楚重复的 9 是否可以接受。
假设没问题,我 运行 尝试使用列表理解来查看它是否有所不同。我的结果:
from numba import jit
def createLists():
l1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
l2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
@jit
def listComp():
l1, l2 = createLists()
return [i for i in l1 for j in l2 if i == j]
%timeit listComp()
5.84 微秒 +/- 10.5 纳秒
或者,如果您可以使用 Numpy,此代码会更快并删除重复的“9”,并且使用 Numba 签名会更快。
import numpy as np
from numba import jit, int64
@jit(int64[:](int64[:], int64[:]))
def JitListComp(l1, l2):
l3 = np.array([i for i in l1 for j in l2 if i == j])
return np.unique(l3) # and i not in crossSec]
@jit
def CreateList():
l1 = np.array([15, 9, 10, 56, 23, 78, 5, 4, 9])
l2 = np.array([9, 4, 5, 36, 47, 26, 10, 45, 87])
return JitListComp(l1, l2)
CreateList()
Out[39]: array([ 4, 5, 9, 10])
%timeit CreateList()
1.71 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
我想知道在 numba 函数中计算两个列表的交集的最快方法。只是为了澄清:两个列表的交集示例:
Input :
lst1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
lst2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
Output :
[9, 10, 4, 5]
问题是,这需要在 numba 函数中计算,因此例如套不能用。你有好主意吗? 我当前的代码非常基础。我认为还有改进的余地。
@nb.njit
def intersection:
result = []
for element1 in lst1:
for element2 in lst2:
if element1 == element2:
result.append(element1)
....
您可以为此使用集合操作:
def intersection(lst1, lst2):
return list(set(lst1) & set(lst2))
然后简单地调用函数intersection(lst1,lst2)
。这将是最简单的方法。
由于 numba 以机器码形式编译和运行您的代码,因此您可能最适合这种简单的操作。 我 运行 下面的一些基准测试
@nb.njit
def loop_intersection(lst1, lst2):
result = []
for element1 in lst1:
for element2 in lst2:
if element1 == element2:
result.append(element1)
return result
@nb.njit
def set_intersect(lst1, lst2):
return set(lst1).intersection(set(lst2))
结果
loop_intersection
40.4 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
set_intersect
42 µs ± 6.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
我试了一下这个以尝试学习一些东西,意识到答案已经给出了。当我 运行 接受的答案时,我得到 return 值 [9, 10, 5, 4, 9]。我不清楚重复的 9 是否可以接受。 假设没问题,我 运行 尝试使用列表理解来查看它是否有所不同。我的结果:
from numba import jit
def createLists():
l1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
l2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
@jit
def listComp():
l1, l2 = createLists()
return [i for i in l1 for j in l2 if i == j]
%timeit listComp() 5.84 微秒 +/- 10.5 纳秒
或者,如果您可以使用 Numpy,此代码会更快并删除重复的“9”,并且使用 Numba 签名会更快。
import numpy as np
from numba import jit, int64
@jit(int64[:](int64[:], int64[:]))
def JitListComp(l1, l2):
l3 = np.array([i for i in l1 for j in l2 if i == j])
return np.unique(l3) # and i not in crossSec]
@jit
def CreateList():
l1 = np.array([15, 9, 10, 56, 23, 78, 5, 4, 9])
l2 = np.array([9, 4, 5, 36, 47, 26, 10, 45, 87])
return JitListComp(l1, l2)
CreateList()
Out[39]: array([ 4, 5, 9, 10])
%timeit CreateList()
1.71 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)