Quick method to enumerate two big arrays?
I have two big arrays to process, but let's look at the simplified example below to get the idea:

I want to find whether elements in data1 match elements in data2, and return the array indices of data1 and data2 as a new array in the form [index of data1, index of data2] whenever a match is found. For example, for the data1 and data2 below, the program would return:

data1 = [[1,1],[2,5],[623,781]]
data2 = [[1,1],[161,74],[357,17],[1,1]]
expected_output = [[0,0],[0,3]]

My current code is as follows:
result = []
for index, item in enumerate(data1):
    for index2, item2 in enumerate(data2):
        if np.array_equal(item, item2):
            result.append([index, index2])
>>> result
[[0, 0], [0, 3]]
This works fine. However, the two actual arrays I'm working with have 600,000 items each, and the code above would be very slow. Is there any way to speed this up?

Probably not the fastest, but simple and reasonably fast: use KDTrees:
>>> data1 = [[1,1],[2,5],[623,781]]
>>> data2 = [[1,1], [161,74],[357,17],[1,1]]
>>>
>>> import numpy as np
>>> from operator import itemgetter
>>> from scipy.spatial import cKDTree as KDTree
>>>
>>> def intersect(a, b):
... A = KDTree(a); B = KDTree(b); X = A.query_ball_tree(B, 0.5)
... ai, bi = zip(*filter(itemgetter(1), enumerate(X)))
... ai = np.repeat(ai, np.fromiter(map(len, bi), int, len(ai)))
... bi = np.concatenate(bi)
... return ai, bi
...
>>> intersect(data1, data2)
(array([0, 0]), array([0, 3]))
Two fake datasets of 1,000,000 pairs each take about 3 seconds:
>>> from time import perf_counter
>>>
>>> a = np.random.randint(0, 100000, (1000000, 2))
>>> b = np.random.randint(0, 100000, (1000000, 2))
>>> t = perf_counter(); intersect(a, b); s = perf_counter()
(array([ 971, 3155, 15034, 35844, 41173, 60467, 73758, 91585,
97136, 105296, 121005, 121658, 124142, 126111, 133593, 141889,
150299, 165881, 167420, 174844, 179410, 192858, 222345, 227722,
233547, 234932, 243683, 248863, 255784, 264908, 282948, 282951,
285346, 287276, 302142, 318933, 327837, 328595, 332435, 342289,
344780, 350286, 355322, 370691, 377459, 401086, 412310, 415688,
442978, 461111, 469857, 491504, 493915, 502945, 506983, 507075,
511610, 515631, 516080, 532457, 541138, 546281, 550592, 551751,
554482, 568418, 571825, 591491, 594428, 603048, 639900, 648278,
666410, 672724, 708500, 712873, 724467, 740297, 740640, 749559,
752723, 761026, 777911, 790371, 791214, 793415, 795352, 801873,
811260, 815527, 827915, 848170, 861160, 892562, 909555, 918745,
924090, 929919, 933605, 939789, 940788, 940958, 950718, 950804,
997947]), array([507017, 972033, 787596, 531935, 590375, 460365, 17480, 392726,
552678, 545073, 128635, 590104, 251586, 340475, 330595, 783361,
981598, 677225, 80580, 38991, 304132, 157839, 980986, 881068,
308195, 162984, 618145, 68512, 58426, 190708, 123356, 568864,
583337, 128244, 106965, 528053, 626051, 391636, 868254, 296467,
39446, 791298, 356664, 428875, 143312, 356568, 736283, 902291,
5607, 475178, 902339, 312950, 891330, 941489, 93635, 884057,
329780, 270399, 633109, 106370, 626170, 54185, 103404, 658922,
108909, 641246, 711876, 496069, 835306, 745188, 328947, 975464,
522226, 746501, 642501, 489770, 859273, 890416, 62451, 463659,
884001, 980820, 171523, 222668, 203244, 149955, 134192, 369508,
905913, 839301, 758474, 114597, 534015, 381467, 7328, 447698,
651929, 137424, 975677, 758923, 982976, 778075, 95266, 213456,
210555]))
>>> print(s-t)
2.98617472499609
Note: the other answers using a dictionary (for exact matches) or a KDTree (for epsilon-close matches) are much better than this one: faster and more memory-efficient.
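As a sketch of the exact-match direction mentioned in that note, a numpy-only option (my own addition, not from the original answers; it assumes 2-D integer arrays with identical dtypes) is to view each row as a single opaque value and use np.intersect1d. Note that, like a plain dictionary, this reports only the first occurrence of each matching row in each array, so duplicate rows need a second pass:

```python
import numpy as np

def intersect_rows(a, b):
    """Return indices (ai, bi) of the first occurrence of each row
    that appears in both a and b."""
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    # View each row as one void scalar so whole rows compare at once
    dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    av = a.view(dt).ravel()
    bv = b.view(dt).ravel()
    _, ai, bi = np.intersect1d(av, bv, return_indices=True)
    return ai, bi

data1 = np.array([[1, 1], [2, 5], [623, 781]])
data2 = np.array([[1, 1], [161, 74], [357, 17], [1, 1]])
print(intersect_rows(data1, data2))  # first match only: (array([0]), array([0]))
```

Because the second [1,1] in data2 is a duplicate, this returns only the pair (0, 0), not (0, 3); the duplicate-aware dictionary variant below handles that case.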
Use scipy.spatial.distance.cdist. If your two data arrays have N and M entries respectively, it produces an N-by-M array of pairwise distances. If you can fit that into RAM, it's easy to find the indices of the matches:
import numpy as np
from scipy.spatial.distance import cdist
# Generate some data that's very likely to have repeats
a = np.random.randint(0, 100, (1000, 2))
b = np.random.randint(0, 100, (1000, 2))
# `cityblock` is likely the cheapest distance to calculate (no sqrt, etc.)
c = cdist(a, b, 'cityblock')
# And the indexes of all the matches:
aidx, bidx = np.nonzero(c == 0)
# sanity check:
print([(a[i], b[j]) for i,j in zip(aidx, bidx)])
The above prints:
[(array([ 0, 84]), array([ 0, 84])),
(array([50, 73]), array([50, 73])),
(array([53, 86]), array([53, 86])),
(array([96, 85]), array([96, 85])),
(array([95, 18]), array([95, 18])),
(array([ 4, 59]), array([ 4, 59])), ... ]
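If the full N-by-M matrix is too large for RAM, the same zero-distance test can be done in chunks. Below is a sketch (my own, not from the answer) that computes the cityblock distances for one block of rows at a time using plain numpy broadcasting; chunk is an assumed tuning parameter:

```python
import numpy as np

def matches_chunked(a, b, chunk=1000):
    """Find all index pairs (i, j) with a[i] == b[j], materializing only a
    (chunk, M) distance block at a time instead of the full (N, M) matrix."""
    aidx, bidx = [], []
    for start in range(0, len(a), chunk):
        block = a[start:start + chunk]
        # Cityblock distance via broadcasting: |a_i - b_j| summed over coords
        c = np.abs(block[:, None, :] - b[None, :, :]).sum(axis=-1)
        i, j = np.nonzero(c == 0)
        aidx.append(i + start)  # shift back to global row indices
        bidx.append(j)
    return np.concatenate(aidx), np.concatenate(bidx)

a = np.array([[1, 1], [2, 5], [623, 781]])
b = np.array([[1, 1], [161, 74], [357, 17], [1, 1]])
print(matches_chunked(a, b, chunk=2))  # (array([0, 0]), array([0, 3]))
```

Peak memory is then proportional to chunk * M instead of N * M, at the cost of a Python-level loop over the blocks.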
Since your data are all integers, you can use a dictionary (hash table); it takes 0.55 seconds on the same data as Paul's answer. This won't necessarily find all paired duplicates between a and b (i.e., if a and b themselves contain repeated rows), but it's easy to modify to do so, or to make a second pass afterwards (over the matched items only) to check for other occurrences of those vectors in the data.
import numpy as np

def intersect1(a, b):
    a_d = {}
    for i, x in enumerate(a):
        a_d[x] = i
    for i, y in enumerate(b):
        if y in a_d:
            yield a_d[y], i
from time import perf_counter
a = list(tuple(x) for x in list(np.random.randint(0, 100000, (1000000, 2))))
b = list(tuple(x) for x in list(np.random.randint(0, 100000, (1000000, 2))))
t = perf_counter(); print(list(intersect1(a, b))); s = perf_counter()
print(s-t)
For comparison, Paul's takes 2.46 seconds on my machine.
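The "easy modification" mentioned above, which also catches repeated rows on either side, could look like this sketch (my own variant: each row of a maps to all of its indices rather than just the last one):

```python
from collections import defaultdict

def intersect_all(a, b):
    """Yield every pair (i, j) with a[i] == b[j], including duplicates."""
    a_d = defaultdict(list)
    for i, x in enumerate(a):
        a_d[x].append(i)          # keep *all* indices of each row
    for j, y in enumerate(b):
        for i in a_d.get(y, ()):  # every i in a pairs with this j
            yield i, j

data1 = [(1, 1), (2, 5), (623, 781)]
data2 = [(1, 1), (161, 74), (357, 17), (1, 1)]
print(list(intersect_all(data1, data2)))  # [(0, 0), (0, 3)]
```

This keeps the same O(N + M + P) behavior (P being the number of output pairs), so on data without many repeats it should run in roughly the same time as intersect1.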