nopython 模式下 Numba jitted 函数的线性组合
Linear Combinations of Numba jitted functions in nopython mode
我一直在尝试使用 numba 自动生成/jit 函数。
您可以从一个 jit 函数中调用其他 jit 函数,因此如果您有一组特定的函数,就很容易在我想要的功能中进行硬编码:
from numba import jit
@jit(nopython=True)
def f1(x):
return (x - 2.0)**2
@jit(nopython=True)
def f2(x):
return (x - 5.0)**2
def hardcoded(x, c):
@jit(nopython=True)
def f(x):
return c[0] * f1(x) + c[1] * f2(x)
return f
lincomb = hardcoded(3, (0.5, 0.5))
print(lincomb(2))
Out: 4.5
但是,假设您事先不知道 f1、f2 是什么。我希望能够使用工厂生成函数,然后使用另一个工厂生成其线性组合:
def f_factory(x0):
@jit(nopython=True)
def f(x):
return (x - x0)**2
return f
def linear_comb(funcs, coeffs, nopython=True):
@jit(nopython=nopython)
def lc(x):
total = 0.0
for f, c in zip(funcs, coeffs):
total += c * f(x)
return total
return lc
并在运行时调用它。这在没有 nopython 模式的情况下有效:
funcs = (f_factory(2.0), f_factory(5.0))
lc = linear_comb(funcs, (0.5, 0.5), nopython=False)
print(lc(2))
Out: 4.5
但是没有使用 nopython 模式。
lc = linear_comb(funcs, (0.5, 0.5), nopython=True)
print(lc(2))
TypingError: Failed at nopython (nopython frontend)
Untyped global name 'funcs': cannot determine Numba type of <class 'tuple'>
File "<ipython-input-100-2d3fb6214044>", line 11
看来 numba 在处理 jit 函数的元组时遇到了问题。有什么方法可以让这种行为起作用吗?
函数集和 c 可能会变大,所以我真的很想让它在 nopython 模式下编译。
可能有更好的方法来做到这一点,但作为一种 hacky 解决方法,您确实可以使用一些 "templating" 来生成元组中每个函数的唯一名称和调用。
def linear_comb(funcs, coeffs, nopython=True):
scope = {'coeffs': coeffs}
stmts = [
'def lc(x):',
' total = 0.0',
]
for i, f in enumerate(funcs):
# give each function a unique name
scope[f'_f{i}'] = f
# codegen for total line
stmts.append(f' total += coeffs[{i}] * _f{i}(x)')
stmts.append(' return total')
code = '\n'.join(stmts)
exec(code, scope)
lc = jit(nopython=nopython)(scope['lc'])
return lc
lc = linear_comb(funcs, (0.5, 0.5), nopython=True)
lc(2)
Out[103]: 4.5
我一直在尝试使用 numba 自动生成/jit 函数。
您可以从一个 jit 函数中调用其他 jit 函数,因此如果您有一组特定的函数,就很容易在我想要的功能中进行硬编码:
from numba import jit
@jit(nopython=True)
def f1(x):
return (x - 2.0)**2
@jit(nopython=True)
def f2(x):
return (x - 5.0)**2
def hardcoded(x, c):
@jit(nopython=True)
def f(x):
return c[0] * f1(x) + c[1] * f2(x)
return f
lincomb = hardcoded(3, (0.5, 0.5))
print(lincomb(2))
Out: 4.5
但是,假设您事先不知道 f1、f2 是什么。我希望能够使用工厂生成函数,然后使用另一个工厂生成其线性组合:
def f_factory(x0):
@jit(nopython=True)
def f(x):
return (x - x0)**2
return f
def linear_comb(funcs, coeffs, nopython=True):
@jit(nopython=nopython)
def lc(x):
total = 0.0
for f, c in zip(funcs, coeffs):
total += c * f(x)
return total
return lc
并在运行时调用它。这在没有 nopython 模式的情况下有效:
funcs = (f_factory(2.0), f_factory(5.0))
lc = linear_comb(funcs, (0.5, 0.5), nopython=False)
print(lc(2))
Out: 4.5
但是没有使用 nopython 模式。
lc = linear_comb(funcs, (0.5, 0.5), nopython=True)
print(lc(2))
TypingError: Failed at nopython (nopython frontend)
Untyped global name 'funcs': cannot determine Numba type of <class 'tuple'>
File "<ipython-input-100-2d3fb6214044>", line 11
看来 numba 在处理 jit 函数的元组时遇到了问题。有什么方法可以让这种行为起作用吗?
函数集和 c 可能会变大,所以我真的很想让它在 nopython 模式下编译。
可能有更好的方法来做到这一点,但作为一种 hacky 解决方法,您确实可以使用一些 "templating" 来生成元组中每个函数的唯一名称和调用。
def linear_comb(funcs, coeffs, nopython=True):
scope = {'coeffs': coeffs}
stmts = [
'def lc(x):',
' total = 0.0',
]
for i, f in enumerate(funcs):
# give each function a unique name
scope[f'_f{i}'] = f
# codegen for total line
stmts.append(f' total += coeffs[{i}] * _f{i}(x)')
stmts.append(' return total')
code = '\n'.join(stmts)
exec(code, scope)
lc = jit(nopython=nopython)(scope['lc'])
return lc
lc = linear_comb(funcs, (0.5, 0.5), nopython=True)
lc(2)
Out[103]: 4.5