加速将函数作为 numba 参数的函数
speed up function that takes a function as argument with numba
我正在尝试使用 numba
来加速将另一个函数作为参数的函数。一个最小的例子如下:
import numba as nb
def f(x):
return x*x
@nb.jit(nopython=True)
def call_func(func,x):
return func(x)
if __name__ == '__main__':
print(call_func(f,5))
然而,这不起作用,因为显然 numba
不知道如何处理该函数参数。回溯很长:
Traceback (most recent call last):
File "numba_function.py", line 15, in <module>
print(call_func(f,5))
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
raise e
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
return self.compile(tuple(argtypes))
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 579, in compile
cres = self._compiler.compile(args, return_type)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 80, in compile
flags=flags, locals=self.locals)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 740, in compile_extra
return pipeline.compile_extra(func)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 360, in compile_extra
return self._compile_bytecode()
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 699, in _compile_bytecode
return self._compile_core()
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 686, in _compile_core
res = pm.run(self.status)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 246, in run
raise patched_exception
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
stage()
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 452, in stage_nopython_frontend
self.locals)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 841, in type_inference_stage
infer.propagate()
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 773, in propagate
raise errors[0]
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 129, in propagate
constraint(typeinfer)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 380, in __call__
self.resolve(typeinfer, typevars, fnty)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 402, in resolve
raise TypingError(msg, loc=self.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of pyobject with parameters (int64)
No type info available for pyobject as a callable.
File "numba_function.py", line 10
[1] During: resolving callee type: pyobject
[2] During: typing of call at numba_function.py (10)
This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'function'>
有办法解决这个问题吗?
如错误消息所示,Numba 无法处理 function
类型的值。您可以查看 the documentation Numba 可以使用哪些类型。原因是 Numba 通常无法在 noptyhon
模式下优化(jit-compile)任意函数,它们基本上被认为是一个黑盒子(事实上,传递的函数甚至可以是本机函数!)。
通常的方法是要求 Numba 优化被调用的函数。如果您不能将装饰器添加到函数中(例如,因为它不是源代码的一部分),您仍然可以手动使用它,如:
import numba as nb
def f(x):
return x*x
if __name__ == '__main__':
f_opt = nb.jit(nopython=True)(f)
print(f_opt(5))
显然,如果 f
也不能被 Numba 编译,它仍然会失败,但在那种情况下,你无能为力。
这取决于你传递给call_func
的func
是否可以在nopython
模式下编译。
如果它不能在 nopython 模式下编译那么这是不可能的,因为 numba 不支持 python 在 nopython 函数中调用(这就是为什么它是叫不python).
然而,如果它可以在 nopython 模式下编译,您可以使用闭包:
import numba as nb
def f(x):
return x*x
def call_func(func, x):
func = nb.njit(func) # compile func in nopython mode!
@nb.njit
def inner(x):
return func(x)
return inner(x)
if __name__ == '__main__':
print(call_func(f,5))
这种方法有一些明显的缺点,因为每次调用 call_func
时都需要编译 func
和 inner
。这意味着它只有在通过编译函数的加速大于编译成本时才可行。如果您多次使用相同的函数调用 call_func
,则可以减轻这种开销:
import numba as nb
def f(x):
return x*x
def call_func(func): # only take func
func = nb.njit(func) # compile func in nopython mode!
@nb.njit
def inner(x):
return func(x)
return inner # return the closure
if __name__ == '__main__':
call_func_with_f = call_func(f) # compile once
print(call_func_with_f(5)) # call the compiled version
print(call_func_with_f(5)) # call the compiled version
print(call_func_with_f(5)) # call the compiled version
print(call_func_with_f(5)) # call the compiled version
print(call_func_with_f(5)) # call the compiled version
一般说明:我不会创建带有函数参数的 numba 函数。如果您不能对函数进行硬编码,numba 就无法生成真正快速的函数,并且如果您还包括闭包的编译成本,那基本上是不值得的。
我正在尝试使用 numba
来加速将另一个函数作为参数的函数。一个最小的例子如下:
import numba as nb
def f(x):
return x*x
@nb.jit(nopython=True)
def call_func(func,x):
return func(x)
if __name__ == '__main__':
print(call_func(f,5))
然而,这不起作用,因为显然 numba
不知道如何处理该函数参数。回溯很长:
Traceback (most recent call last):
File "numba_function.py", line 15, in <module>
print(call_func(f,5))
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
raise e
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
return self.compile(tuple(argtypes))
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 579, in compile
cres = self._compiler.compile(args, return_type)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 80, in compile
flags=flags, locals=self.locals)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 740, in compile_extra
return pipeline.compile_extra(func)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 360, in compile_extra
return self._compile_bytecode()
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 699, in _compile_bytecode
return self._compile_core()
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 686, in _compile_core
res = pm.run(self.status)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 246, in run
raise patched_exception
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
stage()
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 452, in stage_nopython_frontend
self.locals)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 841, in type_inference_stage
infer.propagate()
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 773, in propagate
raise errors[0]
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 129, in propagate
constraint(typeinfer)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 380, in __call__
self.resolve(typeinfer, typevars, fnty)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 402, in resolve
raise TypingError(msg, loc=self.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of pyobject with parameters (int64)
No type info available for pyobject as a callable.
File "numba_function.py", line 10
[1] During: resolving callee type: pyobject
[2] During: typing of call at numba_function.py (10)
This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'function'>
有办法解决这个问题吗?
如错误消息所示,Numba 无法处理 function
类型的值。您可以查看 the documentation Numba 可以使用哪些类型。原因是 Numba 通常无法在 noptyhon
模式下优化(jit-compile)任意函数,它们基本上被认为是一个黑盒子(事实上,传递的函数甚至可以是本机函数!)。
通常的方法是要求 Numba 优化被调用的函数。如果您不能将装饰器添加到函数中(例如,因为它不是源代码的一部分),您仍然可以手动使用它,如:
import numba as nb
def f(x):
return x*x
if __name__ == '__main__':
f_opt = nb.jit(nopython=True)(f)
print(f_opt(5))
显然,如果 f
也不能被 Numba 编译,它仍然会失败,但在那种情况下,你无能为力。
这取决于你传递给call_func
的func
是否可以在nopython
模式下编译。
如果它不能在 nopython 模式下编译那么这是不可能的,因为 numba 不支持 python 在 nopython 函数中调用(这就是为什么它是叫不python).
然而,如果它可以在 nopython 模式下编译,您可以使用闭包:
import numba as nb
def f(x):
return x*x
def call_func(func, x):
func = nb.njit(func) # compile func in nopython mode!
@nb.njit
def inner(x):
return func(x)
return inner(x)
if __name__ == '__main__':
print(call_func(f,5))
这种方法有一些明显的缺点,因为每次调用 call_func
时都需要编译 func
和 inner
。这意味着它只有在通过编译函数的加速大于编译成本时才可行。如果您多次使用相同的函数调用 call_func
,则可以减轻这种开销:
import numba as nb
def f(x):
return x*x
def call_func(func): # only take func
func = nb.njit(func) # compile func in nopython mode!
@nb.njit
def inner(x):
return func(x)
return inner # return the closure
if __name__ == '__main__':
call_func_with_f = call_func(f) # compile once
print(call_func_with_f(5)) # call the compiled version
print(call_func_with_f(5)) # call the compiled version
print(call_func_with_f(5)) # call the compiled version
print(call_func_with_f(5)) # call the compiled version
print(call_func_with_f(5)) # call the compiled version
一般说明:我不会创建带有函数参数的 numba 函数。如果您不能对函数进行硬编码,numba 就无法生成真正快速的函数,并且如果您还包括闭包的编译成本,那基本上是不值得的。