numba @njit to update a big dict
I'm trying to use numba for a function that needs to do lookups in a very large (10e6) dict with (int, int) tuples as keys.
import numpy as np
from numba import njit
myarray = np.array([[0, 0], # 0, 1
[0, 1],
[1, 1], # 1, 2
[1, 2], # 1, 3
[2, 2],
[1, 3]]
) # a lot of this with shape~(10e6, 2)
dict_with_tuples_key = {(0, 1): 1,
(3, 7): 1} # ~10e6 keys
A simplified version looks like this:
# @njit
def update_dict(dict_with_tuples_key, myarray):
    for line in myarray:
        i, j = line
        if (i, j) in dict_with_tuples_key:
            dict_with_tuples_key[(i, j)] += 1
        else:
            dict_with_tuples_key[(i, j)] = 1
    return dict_with_tuples_key

new_dict = update_dict(dict_with_tuples_key, myarray)
print(new_dict)
# {(0, 1): 2,  # +1, already in dict_with_tuples_key
#  (0, 0): 1,  # diag
#  (1, 1): 1,  # diag
#  (2, 2): 1,  # diag
#  (1, 2): 1,  # new from myarray
#  (1, 3): 1,  # new from myarray
#  (3, 7): 1}
@njit doesn't seem to accept a dict as a function argument?
I'm wondering how to rewrite this, especially the `if (i, j) in dict_with_tuples_key` part that does the lookup.
As an alternative, if it is fast enough, you could try:
from collections import Counter
c2 = Counter(dict_with_tuples_key)
c1 = Counter(tuple(x) for x in myarray)
new_dict = dict(c1 + c2)
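A small variation on the Counter idea, offered as a sketch rather than a benchmarked recommendation: `Counter.update` adds to the existing counts in place instead of building a third Counter with `+`, and calling `myarray.tolist()` first yields plain Python `int` tuples instead of one NumPy row view per line, which is typically cheaper to iterate and hash. The sample data is the question's; nothing here uses numba.

```python
from collections import Counter

import numpy as np

myarray = np.array([[0, 0], [0, 1], [1, 1],
                    [1, 2], [2, 2], [1, 3]], dtype=np.int64)
dict_with_tuples_key = {(0, 1): 1, (3, 7): 1}

# Start from the existing counts, then add one for every row of myarray.
counts = Counter(dict_with_tuples_key)
# tolist() turns the array into nested lists of plain Python ints,
# so each row becomes an ordinary (int, int) tuple key.
counts.update(map(tuple, myarray.tolist()))

new_dict = dict(counts)
print(new_dict)
# {(0, 1): 2, (3, 7): 1, (0, 0): 1, (1, 1): 1, (1, 2): 1, (2, 2): 1, (1, 3): 1}
```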
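`njit` means the function is compiled in nopython mode. A regular Python `dict` is a Python object, so it cannot be passed as an argument to such a function. (Tuples of numbers, on the other hand, are supported in nopython mode, and since Numba 0.43 `numba.typed.Dict` provides a dictionary type that `@njit` functions can accept.)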
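If your dictionary keys are all distinct, I would consider using a 2D numpy array instead, where the first axis represents the first index of the key tuple and the second axis the second index. This works when the indices are non-negative integers of modest range, since the array needs (max first index + 1) × (max second index + 1) cells. Then you could rewrite it as: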
from numba import njit
import numpy as np

@njit
def update_array(array, myarray):
    elements = myarray.shape[0]
    for i in range(elements):
        # Row i of myarray is an (i, j) pair; use it directly as a 2D index.
        array[myarray[i, 0], myarray[i, 1]] += 1
    return array

myarray = np.array([[0, 0], [0, 1], [1, 1],
                    [1, 2], [2, 2], [1, 3]])
# Calculate the size of the numpy array that replaces the dict:
lens = np.max(myarray, axis=0)  # Maximum value in each column
array = np.zeros((lens[0] + 1, lens[1] + 1), dtype=np.int64)  # One integer counter per possible (i, j) index
update_array(array, myarray)
Since you already index your dict with tuples, switching to array indexing shouldn't be much of a problem.
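To make that conversion concrete, here is a sketch that also folds in the entries already in `dict_with_tuples_key` (the code above sizes the array from `myarray` alone, so a key such as (3, 7) would not fit). It swaps the `@njit` loop for NumPy's `np.add.at`, which is a substitute rather than the answer's method; `update_array` would work just as well for the counting step. It assumes non-negative integer indices, a non-empty dict, and that a dense int64 array of that size fits in memory.

```python
import numpy as np

myarray = np.array([[0, 0], [0, 1], [1, 1],
                    [1, 2], [2, 2], [1, 3]], dtype=np.int64)
dict_with_tuples_key = {(0, 1): 1, (3, 7): 1}

# Size the array so it covers the indices in both myarray and the dict keys.
keys = np.array(list(dict_with_tuples_key.keys()), dtype=np.int64)
values = np.array(list(dict_with_tuples_key.values()), dtype=np.int64)
n_rows, n_cols = np.vstack([myarray, keys]).max(axis=0) + 1

counts = np.zeros((n_rows, n_cols), dtype=np.int64)
# Seed the array with the existing dict entries (dict keys are unique,
# so plain fancy-index assignment is enough here).
counts[keys[:, 0], keys[:, 1]] = values
# np.add.at is unbuffered: every occurrence of a repeated (i, j) pair is counted.
np.add.at(counts, (myarray[:, 0], myarray[:, 1]), 1)

# Back to a dict holding only the non-zero cells.
rows, cols = np.nonzero(counts)
new_dict = {(int(i), int(j)): int(counts[i, j]) for i, j in zip(rows, cols)}
print(new_dict)
# {(0, 0): 1, (0, 1): 2, (1, 1): 1, (1, 2): 1, (1, 3): 1, (2, 2): 1, (3, 7): 1}
```

Note that entries whose count is zero are dropped on the way back, since `np.nonzero` only returns filled cells.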