我怎样才能用 Numba 加速 python 的字典
How can I speed up python's dictionary with Numba
我需要在布尔值数组中存储一些单元格。起初我使用 numpy,但是当数组开始占用大量内存时,我想到了将非零元素存储在以元组为键的字典中(因为它是可散列类型)。例如:
{(0, 0, 0): True, (1, 2, 3): True}
(这是“3D 数组”中的两个单元格,索引为 0,0,0 和 1,2,3,但维数事先未知,并且在我 运行 我的算法时定义)。
它有很大帮助,因为非零单元格只填充数组的一小部分。
为了从这个字典中写入和获取值,我需要使用循环:
def fill_cells(indices, area_dict):
for i in indices:
area_dict[tuple(i)] = 1
def get_cells(indices, area_dict):
n = len(indices)
out = np.zeros(n, dtype=np.bool)
for i in range(n):
out[i] = tuple(indices[i]) in area_dict.keys()
return out
现在我需要用 Numba 加速它。 Numba 不支持原生 Python 的 dict(),所以我使用了 numba.typed.Dict。
问题是 Numba 想在定义函数的阶段知道键的大小,所以我什至无法创建字典(键的长度事先未知并在我调用函数时定义):
@njit
def make_dict(n):
out = {(0,)*n:True}
return out
Numba 无法正确推断字典键的类型并且 returns 错误:
Compilation is falling back to object mode WITH looplifting enabled because Function "make_dict" failed type inference due to: Invalid use of Function(<built-in function mul>) with argument(s) of type(s): (tuple(int64 x 1), int64)
如果我在函数中将 n 更改为具体数字,就可以了。我用这个技巧解决了它:
n = 10
s = '@njit\ndef make_dict():\n\tout = {(0,)*%s:True}\n\treturn out' % n
exec(s)
但我认为这是错误的低效方式。我需要将我的 fill_cells 和 get_cells 函数与@njit 装饰器一起使用,但是 Numba returns 出现相同的错误,因为我试图在这个函数中从 numpy 数组创建元组。
我了解 Numba 的基本局限性(以及一般的编译),但也许有一些方法可以加快功能,或者,也许您有其他解决我的单元格存储问题的方法?
最终解决方案:
主要问题是 Numba 在定义创建它的函数时需要知道元组的长度。诀窍是每次都重新定义函数。我需要使用定义函数的代码生成字符串,并使用 exec():
运行
n = 10
s = '@njit\ndef arr_to_tuple(a):\n\treturn (' + ''.join('a[%i],' % i for i in range(n)) + ')'
exec(s)
之后我可以调用 arr_to_tuple(a) 来创建可以在另一个 @njit - 修饰函数中使用的元组。
例如创建元组键的空字典,需要解决的问题:
@njit
def make_empty_dict():
tpl = arr_to_tuple(np.array([0]*5))
out = {tpl:True}
del out[tpl]
return out
我在字典中写了一个元素,因为它是 Numba 推断类型的一种方式。
此外,我需要使用问题中描述的 fill_cells 和 get_cells 函数。这就是我用 Numba 重写它们的方式:
写元素。刚刚将 tuple() 更改为 arr_to_tuple():
@njit
def fill_cells_nb(indices, area_dict):
for i in range(len(indices)):
area_dict[arr_to_tuple(indices[i])] = True
从字典中获取元素需要一些令人毛骨悚然的代码:
@njit
def get_cells_nb(indices, area_dict):
n = len(indices)
out = np.zeros(n, dtype=np.bool_)
for i in range(n):
new_len = len(area_dict)
tpl = arr_to_tuple(indices[i])
area_dict[tpl] = True
old_len = len(area_dict)
if new_len == old_len:
out[i] = True
else:
del area_dict[tpl]
return out
我的 Numba 版本 (0.46) 不支持 .contains (in) 运算符和 try-except 构造。如果您有支持它的版本,您可以为它编写更多 "regular" 解决方案。
所以当我想检查字典中是否存在具有某些索引的元素时,我会记住它的长度,然后在字典中写一些带有提到的索引的东西。如果长度改变了,我断定该元素不存在。否则该元素存在。看起来解决方案很慢,但事实并非如此。
速度测试:
解决方案的工作速度出奇的快。我用 %timeit 与 native-Python 优化代码进行了比较:
- arr_to_tuple() 比常规 tuple() 函数快 5 倍
- get_cells with numba 与 native-[=76= 相比,单个元素快 3 倍,大元素数组快 40 倍]写成get_cells
- fill_cells with numba 与 native-[=76= 相比,单个元素快 4 倍,大元素数组快 40 倍]写成fill_cells
我需要在布尔值数组中存储一些单元格。起初我使用 numpy,但是当数组开始占用大量内存时,我想到了将非零元素存储在以元组为键的字典中(因为它是可散列类型)。例如:
{(0, 0, 0): True, (1, 2, 3): True}
(这是“3D 数组”中的两个单元格,索引为 0,0,0 和 1,2,3,但维数事先未知,并且在我 运行 我的算法时定义)。
它有很大帮助,因为非零单元格只填充数组的一小部分。
为了从这个字典中写入和获取值,我需要使用循环:
def fill_cells(indices, area_dict):
for i in indices:
area_dict[tuple(i)] = 1
def get_cells(indices, area_dict):
n = len(indices)
out = np.zeros(n, dtype=np.bool)
for i in range(n):
out[i] = tuple(indices[i]) in area_dict.keys()
return out
现在我需要用 Numba 加速它。 Numba 不支持原生 Python 的 dict(),所以我使用了 numba.typed.Dict。 问题是 Numba 想在定义函数的阶段知道键的大小,所以我什至无法创建字典(键的长度事先未知并在我调用函数时定义):
@njit
def make_dict(n):
out = {(0,)*n:True}
return out
Numba 无法正确推断字典键的类型并且 returns 错误:
Compilation is falling back to object mode WITH looplifting enabled because Function "make_dict" failed type inference due to: Invalid use of Function(<built-in function mul>) with argument(s) of type(s): (tuple(int64 x 1), int64)
如果我在函数中将 n 更改为具体数字,就可以了。我用这个技巧解决了它:
n = 10
s = '@njit\ndef make_dict():\n\tout = {(0,)*%s:True}\n\treturn out' % n
exec(s)
但我认为这是错误的低效方式。我需要将我的 fill_cells 和 get_cells 函数与@njit 装饰器一起使用,但是 Numba returns 出现相同的错误,因为我试图在这个函数中从 numpy 数组创建元组。
我了解 Numba 的基本局限性(以及一般的编译),但也许有一些方法可以加快功能,或者,也许您有其他解决我的单元格存储问题的方法?
最终解决方案:
主要问题是 Numba 在定义创建它的函数时需要知道元组的长度。诀窍是每次都重新定义函数。我需要使用定义函数的代码生成字符串,并使用 exec():
运行n = 10
s = '@njit\ndef arr_to_tuple(a):\n\treturn (' + ''.join('a[%i],' % i for i in range(n)) + ')'
exec(s)
之后我可以调用 arr_to_tuple(a) 来创建可以在另一个 @njit - 修饰函数中使用的元组。
例如创建元组键的空字典,需要解决的问题:
@njit
def make_empty_dict():
tpl = arr_to_tuple(np.array([0]*5))
out = {tpl:True}
del out[tpl]
return out
我在字典中写了一个元素,因为它是 Numba 推断类型的一种方式。
此外,我需要使用问题中描述的 fill_cells 和 get_cells 函数。这就是我用 Numba 重写它们的方式:
写元素。刚刚将 tuple() 更改为 arr_to_tuple():
@njit
def fill_cells_nb(indices, area_dict):
for i in range(len(indices)):
area_dict[arr_to_tuple(indices[i])] = True
从字典中获取元素需要一些令人毛骨悚然的代码:
@njit
def get_cells_nb(indices, area_dict):
n = len(indices)
out = np.zeros(n, dtype=np.bool_)
for i in range(n):
new_len = len(area_dict)
tpl = arr_to_tuple(indices[i])
area_dict[tpl] = True
old_len = len(area_dict)
if new_len == old_len:
out[i] = True
else:
del area_dict[tpl]
return out
我的 Numba 版本 (0.46) 不支持 .contains (in) 运算符和 try-except 构造。如果您有支持它的版本,您可以为它编写更多 "regular" 解决方案。
所以当我想检查字典中是否存在具有某些索引的元素时,我会记住它的长度,然后在字典中写一些带有提到的索引的东西。如果长度改变了,我断定该元素不存在。否则该元素存在。看起来解决方案很慢,但事实并非如此。
速度测试:
解决方案的工作速度出奇的快。我用 %timeit 与 native-Python 优化代码进行了比较:
- arr_to_tuple() 比常规 tuple() 函数快 5 倍
- get_cells with numba 与 native-[=76= 相比,单个元素快 3 倍,大元素数组快 40 倍]写成get_cells
- fill_cells with numba 与 native-[=76= 相比,单个元素快 4 倍,大元素数组快 40 倍]写成fill_cells