使用函数对象作为 numba njit 函数的参数
Using a function object as an argument for numba njit function
我想做一个通用的函数,它接受一个函数对象作为参数。
最简单的情况之一:
import numpy as np
import numba as nb
@nb.njit()
def test(a, f=np.median):
return f(a)
test(np.arange(10), np.mean)
给出错误,尽管 test(np.arange(10))
按预期工作。
错误:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
[1] During: typing of argument at <ipython-input-54-52cead0f097d> (5)
File "<ipython-input-54-52cead0f097d>", line 5:
def test(a, f=np.median):
return f(a)
^
This error may have been caused by the following argument(s):
- argument 1: cannot determine Numba type of <class 'function'>
This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.
这是不允许的还是我遗漏了什么?
使用函数作为参数对 numba 来说很棘手,而且非常昂贵。 Frequently Asked Questions: "1.18.1.1. Can I pass a function as an argument to a jitted function?":
中提到了这一点
1.18.1.1. Can I pass a function as an argument to a jitted function?
As of Numba 0.39, you can, so long as the function argument has also been JIT-compiled:
@jit(nopython=True)
def f(g, x):
return g(x) + g(-x)
result = f(jitted_g_function, 1)
However, dispatching with arguments that are functions has extra overhead. If this matters for your application, you can also use a factory function to capture the function argument in a closure:
def make_f(g):
# Note: a new f() is created each time make_f() is called!
@jit(nopython=True)
def f(x):
return g(x) + g(-x)
return f
f = make_f(jitted_g_function)
result = f(1)
Improving the dispatch performance of functions in Numba is an ongoing task.
这意味着您可以选择使用函数工厂:
import numpy as np
import numba as nb
def test(a, func=np.median):
@nb.njit
def _test(a):
return func(a)
return _test(a)
>>> test(np.arange(10))
4.5
>>> test(np.arange(10), np.min)
0
>>> test(np.arange(10), np.mean)
4.5
或者在将函数参数作为参数传递之前将函数参数包装为 jitted-function:
import numpy as np
import numba as nb
@nb.njit()
def test(a, f=np.median):
return f(a)
@nb.njit
def wrapped_mean(a):
return np.mean(a)
@nb.njit
def wrapped_median(a):
return np.median(a)
>>> test(np.arange(10))
4.5
>>> test(np.arange(10), wrapped_mean)
4.5
>>> test(np.arange(10), wrapped_median)
4.5
这两个选项都有很多样板文件,并不像人们希望的那样直截了当。
函数工厂方法也会重复创建和编译函数,因此如果您经常使用与参数相同的函数调用它,您可以使用字典来存储已知的编译函数:
import numpy as np
import numba as nb
_precompiled_funcs = {}
def test(a, func=np.median):
if func not in _precompiled_funcs:
@nb.njit
def _test(arr):
return func(arr)
result = _test(a)
_precompiled_funcs[func] = _test
return result
return _precompiled_funcs[func](a)
另一种方法(使用 wrapped 和 jitted 函数)也有一些开销,但是只要您传入的数组具有大量元素(>1000),它就不会真正引人注意。
如果您展示的函数确实是您想要使用的函数,我根本不会在其上使用 numba。使用 Python + NumPy 这样简单的任务不会锻炼 numba 的强度(索引和迭代数组或大量数字运算)应该更快(或同样快)并且更容易调试和理解:
import numba as nb
def test(a, f=np.median):
return f(a)
我想做一个通用的函数,它接受一个函数对象作为参数。
最简单的情况之一:
import numpy as np
import numba as nb
@nb.njit()
def test(a, f=np.median):
return f(a)
test(np.arange(10), np.mean)
给出错误,尽管 test(np.arange(10))
按预期工作。
错误:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
[1] During: typing of argument at <ipython-input-54-52cead0f097d> (5)
File "<ipython-input-54-52cead0f097d>", line 5:
def test(a, f=np.median):
return f(a)
^
This error may have been caused by the following argument(s):
- argument 1: cannot determine Numba type of <class 'function'>
This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.
这是不允许的还是我遗漏了什么?
使用函数作为参数对 numba 来说很棘手,而且非常昂贵。 Frequently Asked Questions: "1.18.1.1. Can I pass a function as an argument to a jitted function?":
中提到了这一点1.18.1.1. Can I pass a function as an argument to a jitted function?
As of Numba 0.39, you can, so long as the function argument has also been JIT-compiled:
@jit(nopython=True) def f(g, x): return g(x) + g(-x) result = f(jitted_g_function, 1)
However, dispatching with arguments that are functions has extra overhead. If this matters for your application, you can also use a factory function to capture the function argument in a closure:
def make_f(g): # Note: a new f() is created each time make_f() is called! @jit(nopython=True) def f(x): return g(x) + g(-x) return f f = make_f(jitted_g_function) result = f(1)
Improving the dispatch performance of functions in Numba is an ongoing task.
这意味着您可以选择使用函数工厂:
import numpy as np
import numba as nb
def test(a, func=np.median):
@nb.njit
def _test(a):
return func(a)
return _test(a)
>>> test(np.arange(10))
4.5
>>> test(np.arange(10), np.min)
0
>>> test(np.arange(10), np.mean)
4.5
或者在将函数参数作为参数传递之前将函数参数包装为 jitted-function:
import numpy as np
import numba as nb
@nb.njit()
def test(a, f=np.median):
return f(a)
@nb.njit
def wrapped_mean(a):
return np.mean(a)
@nb.njit
def wrapped_median(a):
return np.median(a)
>>> test(np.arange(10))
4.5
>>> test(np.arange(10), wrapped_mean)
4.5
>>> test(np.arange(10), wrapped_median)
4.5
这两个选项都有很多样板文件,并不像人们希望的那样直截了当。
函数工厂方法也会重复创建和编译函数,因此如果您经常使用与参数相同的函数调用它,您可以使用字典来存储已知的编译函数:
import numpy as np
import numba as nb
_precompiled_funcs = {}
def test(a, func=np.median):
if func not in _precompiled_funcs:
@nb.njit
def _test(arr):
return func(arr)
result = _test(a)
_precompiled_funcs[func] = _test
return result
return _precompiled_funcs[func](a)
另一种方法(使用 wrapped 和 jitted 函数)也有一些开销,但是只要您传入的数组具有大量元素(>1000),它就不会真正引人注意。
如果您展示的函数确实是您想要使用的函数,我根本不会在其上使用 numba。使用 Python + NumPy 这样简单的任务不会锻炼 numba 的强度(索引和迭代数组或大量数字运算)应该更快(或同样快)并且更容易调试和理解:
import numba as nb
def test(a, f=np.median):
return f(a)