这个 numba 函数的错误是什么?
What is the error in this numba function about?
我写了这个 python 函数,我相信它会移植到 numba。不幸的是它没有,我不确定我是否理解错误:
Invalid use of getiter with parameters (none)
.
它需要知道发电机的类型吗?是不是因为它returns个变长元组?
from numba import njit
# @njit
def iterator(N, k):
r"""Numba implementation of an iterator over tuples of N integers,
such that sum(tuple) == k.
Args:
N (int): number of elements in the tuple
k (int): sum of the elements
Returns:
tuple(int): a tuple of N integers
"""
if N == 1:
yield (k,)
else:
for i in range(k+1):
for j in iterator(N-1, k-i):
yield (i,) + j
编辑
感谢杰罗姆提供的提示。这是我最终写的解决方案(我从左边开始):
import numpy as np
from numba import njit
@njit
def next_lst(lst, i, reset=False):
r"""Computes the next list of indices given the current list
and the current index.
"""
if lst[i] == 0:
return next_lst(lst, i+1, reset=True)
else:
lst[i] -= 1
lst[i+1] += 1
if reset:
lst[0] = np.sum(lst[:i+1])
lst[1:i+1] = 0
i = 0
return lst, i
@njit
def generator(N, k):
r"""Goes through all the lists of indices recursively.
"""
lst = np.zeros(N, dtype=np.int64)
lst[0] = k
i = 0
yield lst
while lst[-1] < k:
lst, i = next_lst(lst, i)
yield lst
这给出了正确的结果,而且它很成功!
for lst in generator(4,2):
print(lst)
[2 0 0 0]
[1 1 0 0]
[0 2 0 0]
[1 0 1 0]
[0 1 1 0]
[0 0 2 0]
[1 0 0 1]
[0 1 0 1]
[0 0 1 1]
[0 0 0 2]
一个问题来自 variable-sized 元组 输出。实际上,元组就像 Numba 中具有不同类型的结构。它们与列表非常不同,而不是 Python(AFAIK,在 Python 中,元组大致就是无法更改的列表)。在 Numba 中,1 项和 2 项的元组是两种不同的类型。它们不能统一到更通用的类型。问题是函数的 return 值必须是唯一类型。因此,Numba 拒绝在 nopython 模式下编译函数。在 Numba 中解决这个问题的唯一方法是使用列表。
话虽这么说,但即使有list,也报错。文档指出:
Most recursive call patterns are supported. The only restriction is that the recursive callee must have a control-flow path that returns without recursing.
我认为这里没有满足此限制,因为没有 return 声明。话虽这么说,该函数应该隐式 return 一个生成器(其类型取决于...递归函数本身)。还要注意 support of generators 是很新的,递归生成器没有得到很好的支持似乎是合理的。我建议您在 Numba github 上提出一个问题,因为我不确定这是预期的行为。
请注意,不使用递归实现此功能可能效率更高。顺便说一下,只有从 Numba 函数而不是 CPython.
调用这个函数才会更快
我写了这个 python 函数,我相信它会移植到 numba。不幸的是它没有,我不确定我是否理解错误:
Invalid use of getiter with parameters (none)
.
它需要知道发电机的类型吗?是不是因为它returns个变长元组?
from numba import njit
# @njit
def iterator(N, k):
r"""Numba implementation of an iterator over tuples of N integers,
such that sum(tuple) == k.
Args:
N (int): number of elements in the tuple
k (int): sum of the elements
Returns:
tuple(int): a tuple of N integers
"""
if N == 1:
yield (k,)
else:
for i in range(k+1):
for j in iterator(N-1, k-i):
yield (i,) + j
编辑
感谢杰罗姆提供的提示。这是我最终写的解决方案(我从左边开始):
import numpy as np
from numba import njit
@njit
def next_lst(lst, i, reset=False):
r"""Computes the next list of indices given the current list
and the current index.
"""
if lst[i] == 0:
return next_lst(lst, i+1, reset=True)
else:
lst[i] -= 1
lst[i+1] += 1
if reset:
lst[0] = np.sum(lst[:i+1])
lst[1:i+1] = 0
i = 0
return lst, i
@njit
def generator(N, k):
r"""Goes through all the lists of indices recursively.
"""
lst = np.zeros(N, dtype=np.int64)
lst[0] = k
i = 0
yield lst
while lst[-1] < k:
lst, i = next_lst(lst, i)
yield lst
这给出了正确的结果,而且它很成功!
for lst in generator(4,2):
print(lst)
[2 0 0 0]
[1 1 0 0]
[0 2 0 0]
[1 0 1 0]
[0 1 1 0]
[0 0 2 0]
[1 0 0 1]
[0 1 0 1]
[0 0 1 1]
[0 0 0 2]
一个问题来自 variable-sized 元组 输出。实际上,元组就像 Numba 中具有不同类型的结构。它们与列表非常不同,而不是 Python(AFAIK,在 Python 中,元组大致就是无法更改的列表)。在 Numba 中,1 项和 2 项的元组是两种不同的类型。它们不能统一到更通用的类型。问题是函数的 return 值必须是唯一类型。因此,Numba 拒绝在 nopython 模式下编译函数。在 Numba 中解决这个问题的唯一方法是使用列表。
话虽这么说,但即使有list,也报错。文档指出:
Most recursive call patterns are supported. The only restriction is that the recursive callee must have a control-flow path that returns without recursing.
我认为这里没有满足此限制,因为没有 return 声明。话虽这么说,该函数应该隐式 return 一个生成器(其类型取决于...递归函数本身)。还要注意 support of generators 是很新的,递归生成器没有得到很好的支持似乎是合理的。我建议您在 Numba github 上提出一个问题,因为我不确定这是预期的行为。
请注意,不使用递归实现此功能可能效率更高。顺便说一下,只有从 Numba 函数而不是 CPython.
调用这个函数才会更快