在 nopython 模式下使用 Numba 的递归函数错误
Error in recursive function with Numba in nopython mode
我想 运行 Numba 中的递归函数,使用 nopython 模式。直到现在我才收到错误。这是一个非常简单的代码,用户给出一个少于五个元素的元组,然后该函数创建另一个元组,并向该元组添加一个新值(在本例中为数字 3)。重复此操作,直到最终元组的长度为 5。由于某种原因,这不起作用,不知道为什么。
@njit
def tup(a):
if len(a) == 5:
return a
else:
b = a + (3,)
b = tup(b)
return b
例如,如果 a = (0,1)
,我希望最终结果是元组 (0,1,3,3,3)
。
编辑:我正在使用 Numba 0.41.0,我得到的错误是内核快死了,'The kernel appears to have died. It will restart automatically.'
根据当前版本中的this list of proposals:
Recursion support in numba is currently limited to self-recursion with
explicit type annotation for the function. This limitation comes from
the inability to determine the return type of a recursive call.
所以,试试:
from numba import jit
@jit()
def tup(a:tuple) -> tuple:
if len(a) == 5:
return a
return tup(a + (3,))
print(tup((0, 1)))
看看这是否对你更有效。
您不应该这样做的原因有几个:
- 这通常是一种方法,在纯 Python 中可能比在 numba 装饰函数中更快。
- 迭代会更简单并且可能更快,但要注意连接元组通常是一个
O(n)
操作,即使在 numba 中也是如此。所以函数的整体性能将是O(n**2)
。这可以通过使用支持 O(1)
附加的数据结构或支持预分配大小的数据结构来改进。或者干脆不使用 "loopy" 或 "recursive" 方法。
- 您是否尝试过如果省略
njit
装饰器并传入包含 6 个元素的元组会发生什么情况? (提示:它将达到递归限制,因为它永远不会满足递归的结束条件)。
Numba,在编写 0.43.1 时,仅在参数类型在递归之间不改变时支持简单递归。在您的情况下,类型确实发生了变化,您传入了一个 tuple(int64 x 2)
但递归调用试图传入一个不同类型的 tuple(int64 x 3)
。奇怪的是,它在我的电脑上遇到了 Whosebug
- 这似乎是 numba 中的一个错误。
我的建议是使用这个(没有 numba,没有递归):
def tup(a):
if len(a) < 5:
a += (3, ) * (5 - len(a))
return a
这也是returns预期的结果:
>>> tup((1,))
(1, 3, 3, 3, 3)
>>> tup((1, 2))
(1, 2, 3, 3, 3)
我想 运行 Numba 中的递归函数,使用 nopython 模式。直到现在我才收到错误。这是一个非常简单的代码,用户给出一个少于五个元素的元组,然后该函数创建另一个元组,并向该元组添加一个新值(在本例中为数字 3)。重复此操作,直到最终元组的长度为 5。由于某种原因,这不起作用,不知道为什么。
@njit
def tup(a):
if len(a) == 5:
return a
else:
b = a + (3,)
b = tup(b)
return b
例如,如果 a = (0,1)
,我希望最终结果是元组 (0,1,3,3,3)
。
编辑:我正在使用 Numba 0.41.0,我得到的错误是内核快死了,'The kernel appears to have died. It will restart automatically.'
根据当前版本中的this list of proposals:
Recursion support in numba is currently limited to self-recursion with explicit type annotation for the function. This limitation comes from the inability to determine the return type of a recursive call.
所以,试试:
from numba import jit
@jit()
def tup(a:tuple) -> tuple:
if len(a) == 5:
return a
return tup(a + (3,))
print(tup((0, 1)))
看看这是否对你更有效。
您不应该这样做的原因有几个:
- 这通常是一种方法,在纯 Python 中可能比在 numba 装饰函数中更快。
- 迭代会更简单并且可能更快,但要注意连接元组通常是一个
O(n)
操作,即使在 numba 中也是如此。所以函数的整体性能将是O(n**2)
。这可以通过使用支持O(1)
附加的数据结构或支持预分配大小的数据结构来改进。或者干脆不使用 "loopy" 或 "recursive" 方法。 - 您是否尝试过如果省略
njit
装饰器并传入包含 6 个元素的元组会发生什么情况? (提示:它将达到递归限制,因为它永远不会满足递归的结束条件)。
Numba,在编写 0.43.1 时,仅在参数类型在递归之间不改变时支持简单递归。在您的情况下,类型确实发生了变化,您传入了一个 tuple(int64 x 2)
但递归调用试图传入一个不同类型的 tuple(int64 x 3)
。奇怪的是,它在我的电脑上遇到了 Whosebug
- 这似乎是 numba 中的一个错误。
我的建议是使用这个(没有 numba,没有递归):
def tup(a):
if len(a) < 5:
a += (3, ) * (5 - len(a))
return a
这也是returns预期的结果:
>>> tup((1,))
(1, 3, 3, 3, 3)
>>> tup((1, 2))
(1, 2, 3, 3, 3)