Numpy:有效地找到行式公共元素
Numpy: find row-wise common element efficiently
假设我们有两个具有相同行数的二维 numpy 数组 a
和 b
。进一步假设我们知道 a
和 b
的每一行 i
至多有一个共同元素,尽管这个元素可能出现多次。我们怎样才能尽可能高效地找到这个元素?
一个例子:
import numpy as np
a = np.array([[1, 2, 3],
[2, 5, 2],
[5, 4, 4],
[2, 1, 3]])
b = np.array([[4, 5],
[3, 2],
[1, 5],
[0, 5]])
desiredResult = np.array([[np.nan],
[2],
[5],
[np.nan]])
通过沿第一个轴应用 intersect1d
很容易得出一个直接的实现:
from intertools import starmap
desiredResult = np.array(list(starmap(np.intersect1d, zip(a, b))))
显然,使用 python 的内置集合操作甚至更快。将结果转换为所需的形式很容易。
但是,我需要一个尽可能高效的实现。因此,我不喜欢 starmap
,因为我认为它需要对每一行调用 python。我想要一个纯矢量化的选项,并且会很高兴,如果这甚至利用了我们的额外知识,即每行最多有一个公共值。
有没有人知道我可以如何加速任务并更优雅地实施解决方案?我可以使用 C 代码或 cython,但编码工作应该不是太多了。
不确定这是否更快,但我们可以在这里尝试一些事情:
方法 1 np.intersect1d
带列表理解
[np.intersect1d(arr[0], arr[1]) for arr in list(zip(a,b))]
# Out
[array([], dtype=int32), array([2]), array([5]), array([], dtype=int32)]
或列出:
[np.intersect1d(arr[0], arr[1]).tolist() for arr in list(zip(a,b))]
# Out
[[], [2], [5], []]
方法 2 set
含列表理解:
[list(set(arr[0]) & set(arr[1])) for arr in list(zip(a,b))]
# Out
[[], [2], [5], []]
方法 #1
这是基于 -
的矢量化
# Sort each row of a and b in-place
a.sort(1)
b.sort(1)
# Use 2D searchsorted row-wise between a and b
idx = searchsorted2d(a,b)
# "Clip-out" out of bounds indices
idx[idx==a.shape[1]] = 0
# Get mask of valid ones i.e. matches
mask = np.take_along_axis(a,idx,axis=1)==b
# Use argmax to get first match as we know there's at most one match
match_val = np.take_along_axis(b,mask.argmax(1)[:,None],axis=1)
# Finally use np.where to choose between valid match
# (decided by any one True in each row of mask)
out = np.where(mask.any(1)[:,None],match_val,np.nan)
方法 #2
Numba-based 一个为了内存效率 -
from numba import njit
@njit(parallel=True)
def numba_f1(a,b,out):
n,a_ncols = a.shape
b_ncols = b.shape[1]
for i in range(n):
for j in range(a_ncols):
for k in range(b_ncols):
m = a[i,j]==b[i,k]
if m:
break
if m:
out[i] = a[i,j]
break
return out
def find_first_common_elem_per_row(a,b):
out = np.full(len(a),np.nan)
numba_f1(a,b,out)
return out
方法 #3
这是另一个基于堆叠和排序的矢量化 -
r = np.arange(len(a))
ab = np.hstack((a,b))
idx = ab.argsort(1)
ab_s = ab[r[:,None],idx]
m = ab_s[:,:-1] == ab_s[:,1:]
m2 = (idx[:,1:]*m)>=a.shape[1]
m3 = m & m2
out = np.where(m3.any(1),b[r,idx[r,m3.argmax(1)+1]-a.shape[1]],np.nan)
方法 #4
为了优雅,我们可以使用 broadcasting
作为 resource-hungry 方法 -
m = (a[:,None]==b[:,:,None]).any(2)
out = np.where(m.any(1),b[np.arange(len(a)),m.argmax(1)],np.nan)
做一些研究,我发现在 O(n+m) 中检查两个列表是否不相交 运行s,从而 n 和 m 是列表的长度(参见 here)。这个想法是在哈希映射的恒定时间内插入和查找元素运行。因此,将第一个列表中的所有元素插入哈希映射需要 O(n) 次操作,检查第二个列表中的每个元素是否已经在哈希映射中需要 O(m) 操作。因此,基于排序的解决方案 运行 在 O(n log(n) + m log(m)) 中不是渐近最优的。
尽管@Divakar 的解决方案在许多用例中都非常高效,但如果第二维很大,它们的效率就会降低。那么,基于哈希映射的解决方案更适合。我在cython:
中实现如下
import numpy as np
cimport numpy as np
import cython
from libc.math cimport NAN
from libcpp.unordered_map cimport unordered_map
np.import_array()
@cython.boundscheck(False)
@cython.wraparound(False)
def get_common_element2d(np.ndarray[double, ndim=2] arr1,
np.ndarray[double, ndim=2] arr2):
cdef np.ndarray[double, ndim=1] result = np.empty(arr1.shape[0])
cdef int dim1 = arr1.shape[1]
cdef int dim2 = arr2.shape[1]
cdef int i, j
cdef unordered_map[double, int] tmpset = unordered_map[double, int]()
for i in range(arr1.shape[0]):
for j in range(dim1):
# insert arr1[i, j] as key without assigned value
tmpset[arr1[i, j]]
for j in range(dim2):
# check whether arr2[i, j] is in tmpset
if tmpset.count(arr2[i,j]):
result[i] = arr2[i,j]
break
else:
result[i] = NAN
tmpset.clear()
return result
我创建了如下测试用例:
import numpy as np
import timeit
from itertools import starmap
from mycythonmodule import get_common_element2d
m, n = 3000, 3000
a = np.random.rand(m, n)
b = np.random.rand(m, n)
for i, row in enumerate(a):
if np.random.randint(2):
common = np.random.choice(row, 1)
b[i][np.random.choice(np.arange(n), np.random.randint(min(n,20)), False)] = common
# we need to copy the arrays on each test run, otherwise they
# will remain sorted, which would bias the results
%timeit [set(aa).intersection(bb) for aa, bb in zip(a.copy(), b.copy())]
# returns 3.11 s ± 56.8 ms
%timeit list(starmap(np.intersect1d, zip(a.copy(), b.copy)))
# returns 1.83 s ± 55.4
# test sorting method
# divakarsMethod1 is the appraoch #1 in @Divakar's answer
%timeit divakarsMethod1(a.copy(), b.copy())
# returns 1.88 s ± 18 ms
# test hash map method
%timeit get_common_element2d(a.copy(), b.copy())
# returns 1.46 s ± 22.6 ms
这些结果似乎表明朴素的方法实际上比某些矢量化版本更好。但是,如果考虑具有较少列的许多行(不同的用例),则矢量化算法会发挥其优势。在这些情况下,矢量化方法比朴素方法快 5 倍以上,结果证明排序方法是最好的。
结论: 我将使用 HashMap-based cython 版本,因为它是两种用例中最有效的变体之一。如果我必须先设置 cython,我会使用 sorting-based 方法。
假设我们有两个具有相同行数的二维 numpy 数组 a
和 b
。进一步假设我们知道 a
和 b
的每一行 i
至多有一个共同元素,尽管这个元素可能出现多次。我们怎样才能尽可能高效地找到这个元素?
一个例子:
import numpy as np
a = np.array([[1, 2, 3],
[2, 5, 2],
[5, 4, 4],
[2, 1, 3]])
b = np.array([[4, 5],
[3, 2],
[1, 5],
[0, 5]])
desiredResult = np.array([[np.nan],
[2],
[5],
[np.nan]])
通过沿第一个轴应用 intersect1d
很容易得出一个直接的实现:
from intertools import starmap
desiredResult = np.array(list(starmap(np.intersect1d, zip(a, b))))
显然,使用 python 的内置集合操作甚至更快。将结果转换为所需的形式很容易。
但是,我需要一个尽可能高效的实现。因此,我不喜欢 starmap
,因为我认为它需要对每一行调用 python。我想要一个纯矢量化的选项,并且会很高兴,如果这甚至利用了我们的额外知识,即每行最多有一个公共值。
有没有人知道我可以如何加速任务并更优雅地实施解决方案?我可以使用 C 代码或 cython,但编码工作应该不是太多了。
不确定这是否更快,但我们可以在这里尝试一些事情:
方法 1 np.intersect1d
带列表理解
[np.intersect1d(arr[0], arr[1]) for arr in list(zip(a,b))]
# Out
[array([], dtype=int32), array([2]), array([5]), array([], dtype=int32)]
或列出:
[np.intersect1d(arr[0], arr[1]).tolist() for arr in list(zip(a,b))]
# Out
[[], [2], [5], []]
方法 2 set
含列表理解:
[list(set(arr[0]) & set(arr[1])) for arr in list(zip(a,b))]
# Out
[[], [2], [5], []]
方法 #1
这是基于
# Sort each row of a and b in-place
a.sort(1)
b.sort(1)
# Use 2D searchsorted row-wise between a and b
idx = searchsorted2d(a,b)
# "Clip-out" out of bounds indices
idx[idx==a.shape[1]] = 0
# Get mask of valid ones i.e. matches
mask = np.take_along_axis(a,idx,axis=1)==b
# Use argmax to get first match as we know there's at most one match
match_val = np.take_along_axis(b,mask.argmax(1)[:,None],axis=1)
# Finally use np.where to choose between valid match
# (decided by any one True in each row of mask)
out = np.where(mask.any(1)[:,None],match_val,np.nan)
方法 #2
Numba-based 一个为了内存效率 -
from numba import njit
@njit(parallel=True)
def numba_f1(a,b,out):
n,a_ncols = a.shape
b_ncols = b.shape[1]
for i in range(n):
for j in range(a_ncols):
for k in range(b_ncols):
m = a[i,j]==b[i,k]
if m:
break
if m:
out[i] = a[i,j]
break
return out
def find_first_common_elem_per_row(a,b):
out = np.full(len(a),np.nan)
numba_f1(a,b,out)
return out
方法 #3
这是另一个基于堆叠和排序的矢量化 -
r = np.arange(len(a))
ab = np.hstack((a,b))
idx = ab.argsort(1)
ab_s = ab[r[:,None],idx]
m = ab_s[:,:-1] == ab_s[:,1:]
m2 = (idx[:,1:]*m)>=a.shape[1]
m3 = m & m2
out = np.where(m3.any(1),b[r,idx[r,m3.argmax(1)+1]-a.shape[1]],np.nan)
方法 #4
为了优雅,我们可以使用 broadcasting
作为 resource-hungry 方法 -
m = (a[:,None]==b[:,:,None]).any(2)
out = np.where(m.any(1),b[np.arange(len(a)),m.argmax(1)],np.nan)
做一些研究,我发现在 O(n+m) 中检查两个列表是否不相交 运行s,从而 n 和 m 是列表的长度(参见 here)。这个想法是在哈希映射的恒定时间内插入和查找元素运行。因此,将第一个列表中的所有元素插入哈希映射需要 O(n) 次操作,检查第二个列表中的每个元素是否已经在哈希映射中需要 O(m) 操作。因此,基于排序的解决方案 运行 在 O(n log(n) + m log(m)) 中不是渐近最优的。
尽管@Divakar 的解决方案在许多用例中都非常高效,但如果第二维很大,它们的效率就会降低。那么,基于哈希映射的解决方案更适合。我在cython:
中实现如下import numpy as np
cimport numpy as np
import cython
from libc.math cimport NAN
from libcpp.unordered_map cimport unordered_map
np.import_array()
@cython.boundscheck(False)
@cython.wraparound(False)
def get_common_element2d(np.ndarray[double, ndim=2] arr1,
np.ndarray[double, ndim=2] arr2):
cdef np.ndarray[double, ndim=1] result = np.empty(arr1.shape[0])
cdef int dim1 = arr1.shape[1]
cdef int dim2 = arr2.shape[1]
cdef int i, j
cdef unordered_map[double, int] tmpset = unordered_map[double, int]()
for i in range(arr1.shape[0]):
for j in range(dim1):
# insert arr1[i, j] as key without assigned value
tmpset[arr1[i, j]]
for j in range(dim2):
# check whether arr2[i, j] is in tmpset
if tmpset.count(arr2[i,j]):
result[i] = arr2[i,j]
break
else:
result[i] = NAN
tmpset.clear()
return result
我创建了如下测试用例:
import numpy as np
import timeit
from itertools import starmap
from mycythonmodule import get_common_element2d
m, n = 3000, 3000
a = np.random.rand(m, n)
b = np.random.rand(m, n)
for i, row in enumerate(a):
if np.random.randint(2):
common = np.random.choice(row, 1)
b[i][np.random.choice(np.arange(n), np.random.randint(min(n,20)), False)] = common
# we need to copy the arrays on each test run, otherwise they
# will remain sorted, which would bias the results
%timeit [set(aa).intersection(bb) for aa, bb in zip(a.copy(), b.copy())]
# returns 3.11 s ± 56.8 ms
%timeit list(starmap(np.intersect1d, zip(a.copy(), b.copy)))
# returns 1.83 s ± 55.4
# test sorting method
# divakarsMethod1 is the appraoch #1 in @Divakar's answer
%timeit divakarsMethod1(a.copy(), b.copy())
# returns 1.88 s ± 18 ms
# test hash map method
%timeit get_common_element2d(a.copy(), b.copy())
# returns 1.46 s ± 22.6 ms
这些结果似乎表明朴素的方法实际上比某些矢量化版本更好。但是,如果考虑具有较少列的许多行(不同的用例),则矢量化算法会发挥其优势。在这些情况下,矢量化方法比朴素方法快 5 倍以上,结果证明排序方法是最好的。
结论: 我将使用 HashMap-based cython 版本,因为它是两种用例中最有效的变体之一。如果我必须先设置 cython,我会使用 sorting-based 方法。