Numba:使用具有默认值的参数调用具有显式签名的 jit

Numba: calling jit with explicit signature using arguments with default values

我正在使用 numba 在 numpy 数组上创建一些包含循环的函数。

一切都很好,我可以使用 jit 并且我学会了如何定义签名。

现在我尝试在带有可选参数的函数上使用 jit,例如:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float))'])
def fun(a, b=3):
    return a + b

这行得通,但如果我使用 optional(float64) 而不是 optional(float),则行不通(与 intint64 相同)。我花了 1 个小时试图弄清楚这个语法(实际上,我的一个朋友偶然发现了这个解决方案,因为他忘了在 float 后面写 64),但是,看在我的份上,我不明白为什么是这样的。我在互联网上找不到任何东西,numba 关于该主题的文档充其量也很少(他们指定 optional 应该采用 numba 类型)。

有人知道这是怎么回事吗?我错过了什么?

嗯,不过异常信息应该有提示:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3.0)

这意味着 optional 在这里是错误的选择。事实上optional represents None or "that type"。但是您需要一个可选参数,而不是可以是 floatNone 的参数,例如:

>>> fun(10, None)  # doesn't fail because of the signature!
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

我怀疑它只是 "happens" 为 optional(float) 工作,因为 float 从 numbas 的角度来看只是一个 "arbitary Python object",所以 optional(float)你可以在那里传递 anything (这显然包括不给出参数)。对于 optional(float64),它只能是 Nonefloat64。该类别不够广泛,不允许不提供参数

如果你给出类型 Omitted:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
13.0

然而,Omitted 似乎实际上并未包含在文档中,它有一些 "rough edges"。例如,它不能用那个签名在 nopython 模式下编译,即使没有签名似乎是可能的:

@njit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3):
    return a + b

TypingError: Failed at nopython (nopython frontend)
Invalid usage of + with parameters (float64, class(float64))

-----------

@njit(['float64(float64, float64)', 'float64(float64, Omitted(3.))'])
def fun(a, b=3):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3)

-----------

@njit
def fun(a, b=3):
    return a + b

>>> fun(10.)
13.0