为什么依赖 numba jitt'ed 函数的顺序很重要?
Why does ordering of dependent numba jitt'ed functions matter?
在python中,您可以定义多个以任意顺序相互调用的函数,并且在运行时将调用这些函数。这些函数在脚本中定义的顺序并不重要,只要它们存在即可。例如,以下是有效的并且可以工作
import numpy as np
def func1(arr):
out = np.empty_like(arr)
for i in range(arr.shape[0]):
out[i] = func2(arr[i]) # calling func2 here which is defined below
return out
def func2(a):
out = a + 1
return out
func1
可以调用 func2
,即使 func2
是在 func1
.
之后定义的
但是,如果我用 numba 修饰这些函数,我会得到一个错误
import numpy as np
import numba as nb
@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
out = np.empty_like(arr)
for i in range(arr.shape[0]):
out[i] = func2(arr[i])
return out
@nb.jit("f8(f8)", nopython=True)
def func2(a):
out = a + 1
return out
>>> TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'func2': cannot determine Numba type of <class
'numba.ir.UndefinedType'>
所以 numba 在使用 JIT 编译 func1
时不知道 func2
是什么。只需切换这些函数的顺序即可,因此 func2
在 func1
之前
@nb.jit("f8(f8)", nopython=True)
def func2(a):
out = a + 1
return out
@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
out = np.empty_like(arr)
for i in range(arr.shape[0]):
out[i] = func2(arr[i])
return out
这是为什么?我有一种感觉,纯 python 模式有效,因为 python 是动态类型而不是编译的,而 numba,使用 JIT,根据定义编译函数(因此可能需要完全了解其中发生的一切每个功能?)。但是我不明白为什么numba在遇到一个它没见过的函数时不在范围内搜索所有函数。
简短版本 - 删除 "f8[:](f8[:])"
你的直觉是对的。 Python 函数在 调用 时被查找,这就是为什么它们可以乱序定义的原因。查看带有 dis
(反汇编)模块的 python 字节码可以清楚地看到名称 b
每次调用函数 a
时都会将其作为全局查找。
def a():
return b()
def b():
return 2
import dis
dis.dis(a)
# 2 0 LOAD_GLOBAL 0 (b)
# 2 CALL_FUNCTION 0
# 4 RETURN_VALUE
在 nopython 模式下,numba 需要静态地 知道每个被调用的函数的地址 - 这使得代码更快(不再运行时lookup),也为其他优化打开了大门,比如内联。
也就是说,numba 可以 处理这种情况。通过指定类型签名 ("f8[:](f8[:])"
),您可以强制提前编译。省略它,一个数字将推迟到第一个调用它的函数,它会起作用。
在python中,您可以定义多个以任意顺序相互调用的函数,并且在运行时将调用这些函数。这些函数在脚本中定义的顺序并不重要,只要它们存在即可。例如,以下是有效的并且可以工作
import numpy as np
def func1(arr):
out = np.empty_like(arr)
for i in range(arr.shape[0]):
out[i] = func2(arr[i]) # calling func2 here which is defined below
return out
def func2(a):
out = a + 1
return out
func1
可以调用 func2
,即使 func2
是在 func1
.
但是,如果我用 numba 修饰这些函数,我会得到一个错误
import numpy as np
import numba as nb
@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
out = np.empty_like(arr)
for i in range(arr.shape[0]):
out[i] = func2(arr[i])
return out
@nb.jit("f8(f8)", nopython=True)
def func2(a):
out = a + 1
return out
>>> TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'func2': cannot determine Numba type of <class
'numba.ir.UndefinedType'>
所以 numba 在使用 JIT 编译 func1
时不知道 func2
是什么。只需切换这些函数的顺序即可,因此 func2
在 func1
@nb.jit("f8(f8)", nopython=True)
def func2(a):
out = a + 1
return out
@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
out = np.empty_like(arr)
for i in range(arr.shape[0]):
out[i] = func2(arr[i])
return out
这是为什么?我有一种感觉,纯 python 模式有效,因为 python 是动态类型而不是编译的,而 numba,使用 JIT,根据定义编译函数(因此可能需要完全了解其中发生的一切每个功能?)。但是我不明白为什么numba在遇到一个它没见过的函数时不在范围内搜索所有函数。
简短版本 - 删除 "f8[:](f8[:])"
你的直觉是对的。 Python 函数在 调用 时被查找,这就是为什么它们可以乱序定义的原因。查看带有 dis
(反汇编)模块的 python 字节码可以清楚地看到名称 b
每次调用函数 a
时都会将其作为全局查找。
def a():
return b()
def b():
return 2
import dis
dis.dis(a)
# 2 0 LOAD_GLOBAL 0 (b)
# 2 CALL_FUNCTION 0
# 4 RETURN_VALUE
在 nopython 模式下,numba 需要静态地 知道每个被调用的函数的地址 - 这使得代码更快(不再运行时lookup),也为其他优化打开了大门,比如内联。
也就是说,numba 可以 处理这种情况。通过指定类型签名 ("f8[:](f8[:])"
),您可以强制提前编译。省略它,一个数字将推迟到第一个调用它的函数,它会起作用。