为什么依赖 numba jitt'ed 函数的顺序很重要?

Why does ordering of dependent numba jitt'ed functions matter?

在python中,您可以定义多个以任意顺序相互调用的函数,并且在运行时将调用这些函数。这些函数在脚本中定义的顺序并不重要,只要它们存在即可。例如,以下是有效的并且可以工作

import numpy as np

def func1(arr):
    out = np.empty_like(arr)
    for i in range(arr.shape[0]):
        out[i] = func2(arr[i])  # calling func2 here which is defined below
    return out

def func2(a):
    out = a + 1
    return out

func1 可以调用 func2,即使 func2 是在 func1.

之后定义的

但是,如果我用 numba 修饰这些函数,我会得到一个错误

import numpy as np
import numba as nb


@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
    out = np.empty_like(arr)
    for i in range(arr.shape[0]):
        out[i] = func2(arr[i])
    return out

@nb.jit("f8(f8)", nopython=True)
def func2(a):
    out = a + 1
    return out

>>> TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    Untyped global name 'func2': cannot determine Numba type of <class 
    'numba.ir.UndefinedType'>

所以 numba 在使用 JIT 编译 func1 时不知道 func2 是什么。只需切换这些函数的顺序即可,因此 func2func1

之前
@nb.jit("f8(f8)", nopython=True)
def func2(a):
    out = a + 1
    return out

@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
    out = np.empty_like(arr)
    for i in range(arr.shape[0]):
        out[i] = func2(arr[i])
    return out

这是为什么?我有一种感觉,纯 python 模式有效,因为 python 是动态类型而不是编译的,而 numba,使用 JIT,根据定义编译函数(因此可能需要完全了解其中发生的一切每个功能?)。但是我不明白为什么numba在遇到一个它没见过的函数时不在范围内搜索所有函数。

简短版本 - 删除 "f8[:](f8[:])"

你的直觉是对的。 Python 函数在 调用 时被查找,这就是为什么它们可以乱序定义的原因。查看带有 dis(反汇编)模块的 python 字节码可以清楚地看到名称 b 每次调用函数 a 时都会将其作为全局查找。

def a():
    return b()

def b():
    return 2

import dis
dis.dis(a)
#  2           0 LOAD_GLOBAL              0 (b)
#              2 CALL_FUNCTION            0
#              4 RETURN_VALUE

在 nopython 模式下,numba 需要静态地 知道每个被调用的函数的地址 - 这使得代码更快(不再运行时lookup),也为其他优化打开了大门,比如内联。

也就是说,numba 可以 处理这种情况。通过指定类型签名 ("f8[:](f8[:])"),您可以强制提前编译。省略它,一个数字将推迟到第一个调用它的函数,它会起作用。