在 nopython 模式下使用 Numba 的递归函数错误

Error in recursive function with Numba in nopython mode

我想 运行 Numba 中的递归函数,使用 nopython 模式。直到现在我才收到错误。这是一个非常简单的代码,用户给出一个少于五个元素的元组,然后该函数创建另一个元组,并向该元组添加一个新值(在本例中为数字 3)。重复此操作,直到最终元组的长度为 5。由于某种原因,这不起作用,不知道为什么。

@njit
def tup(a):
    if len(a) == 5:
        return a
    else:
        b = a + (3,)
        b = tup(b)
        return b

例如,如果 a = (0,1),我希望最终结果是元组 (0,1,3,3,3)

编辑:我正在使用 Numba 0.41.0,我得到的错误是内核快死了,'The kernel appears to have died. It will restart automatically.'

根据当前版本中的this list of proposals

Recursion support in numba is currently limited to self-recursion with explicit type annotation for the function. This limitation comes from the inability to determine the return type of a recursive call.

所以,试试:

from numba import jit

@jit()
def tup(a:tuple) -> tuple:
    if len(a) == 5:
        return a

    return tup(a + (3,))

print(tup((0, 1)))

看看这是否对你更有效。

您不应该这样做的原因有几个:

  • 这通常是一种方法,在纯 Python 中可能比在 numba 装饰函数中更快。
  • 迭代会更简单并且可能更快,但要注意连接元组通常是一个 O(n) 操作,即使在 numba 中也是如此。所以函数的整体性能将是O(n**2)。这可以通过使用支持 O(1) 附加的数据结构或支持预分配大小的数据结构来改进。或者干脆不使用 "loopy" 或 "recursive" 方法。
  • 您是否尝试过如果省略 njit 装饰器并传入包含 6 个元素的元组会发生什么情况? (提示:它将达到递归限制,因为它永远不会满足递归的结束条件)。

Numba,在编写 0.43.1 时,仅在参数类型在递归之间不改变时支持简单递归。在您的情况下,类型确实发生了变化,您传入了一个 tuple(int64 x 2) 但递归调用试图传入一个不同类型的 tuple(int64 x 3) 。奇怪的是,它在我的电脑上遇到了 Whosebug - 这似乎是 numba 中的一个错误。

我的建议是使用这个(没有 numba,没有递归):

def tup(a):
    if len(a) < 5:
        a += (3, ) * (5 - len(a))
    return a

这也是returns预期的结果:

>>> tup((1,))
(1, 3, 3, 3, 3)
>>> tup((1, 2))
(1, 2, 3, 3, 3)