Numba JIT 比带有参数化函数的纯 python 慢
Numba JIT slower than pure python with parameterized function
我刚刚写了一篇 simple benchmark 比较 Numba 和 Julia,以及一些讨论。
我想知道是否可以通过某种方式修复我的 Numba 代码,或者 Numba 是否确实不支持我尝试做的事情。
我们的想法是使用 JIT 编译的正交规则来计算这个函数。
g(p) = integrate exp(p*x) with respect to x
这是简单的正交函数:
@nb.njit
def quad_trap(f,a,b,N):
h = (b-a)/N
integral = h * ( f(a) + f(b) ) / 2
for k in range(N):
xk = (b-a) * k/N + a
integral = integral + h*f(xk)
return integral
我可以将 JIT 编译函数传递给此函数,例如:
@nb.njit(nb.float64(nb.float64))
def func(x):
return math.exp(x) - 10
这比纯 Python 快大约 10-20 倍,这非常好。
现在,我想做的是传递 x 的函数并由 p 参数化,类似于:
def g(p):
@nb.njit(nb.float64(nb.float64))
def integrand(x):
return math.exp(p*x) - 10
return quad_trap(integrand, -1, 1, 10000)
而且这样做似乎会破坏 Numba,即使与纯 Python 相比,Numba 也变得非常慢。
是我做错了什么,还是 Numba 确实不支持此功能? (我确实检查了 the documentation 但我不明白问题出在哪里)。谢谢!
TL;DR: Numba 似乎还不支持此功能。
that's about 10-20X faster than pure Python, which is pretty good.
Numba 函数 quad_trap
将在您第一次调用时编译。如果参数的类型发生变化,那么 Numba 将再次重新编译该函数。编译时间通常是不可忽略的(几毫秒到几秒)。为了避免这种情况,解决方案通常是指定参数的类型。但是,据我所知,由于功能的原因,这在这里是不可能的(至少没有记录)。话虽如此,因为您肯定会使用相同的函数对 quad_trap
函数进行基准测试,Numba 不应重新编译该函数,因为提供的参数类型不会改变。
doing seems to break down Numba, which becomes incredibly slow even when compared with pure Python.
在 Numba 的最新版本中,它可以在没有警告的情况下运行,但这是因为函数 integrand
被一遍又一遍地重新编译,因为 Numba 不知道它的代码是否改变了(或者那个 functions/operators 在此函数中递归调用)。在旧版本中,Numba 可能会抱怨函数 integrand
读取的参数 p
是从其父包含函数读取的。这称为闭包。
编译器通常不太支持闭包,因为处理它们要困难得多(它们需要从父函数的堆栈中读取变量)。一个反复出现的普遍问题是闭包可以逃脱其父函数的范围并在外部调用,从而导致未定义的行为(因为闭包将尝试读取已完成函数的失效堆栈)。
一个技巧是将 @nb.njit
装饰器从 integrand
移动到 g
但 Numba 拒绝编译 g
因为它不支持可能逃脱范围的闭包它的父函数(由于前面描述的问题)。请注意,闭包不会转义它在您的案例中定义的函数,但 Numba 无法证明这一点(因为 quad_trap
函数已经编译)并且不幸的是,当函数 quad_trap
时它也无法做到这一点是内联的(虽然理论上可以证明这是安全的)。事实上 documentations 声明:
Numba now supports inner functions as long as they are non-recursive and only called locally, but not passed as argument or returned as result. The use of closure variables (variables defined in outer scopes) within an inner function is also supported.
我认为 @generated_jit
装饰器可能有助于解决此类问题,但我未能成功使其适用于您的特定情况。它至少应该有助于在定义时编译 g
(如 integrand
)而不是在第一次调用期间。
一个解决方案就是不使用闭包:
@nb.njit
def quad_trap_p(f,a,b,N,p):
h = (b-a)/N
integral = h * ( f(a,p) + f(b,p) ) / 2
for k in range(N):
xk = (b-a) * k/N + a
integral = integral + h*f(xk,p)
return integral
@nb.njit(nb.float64(nb.float64, nb.float64))
def integrand(x, p):
return math.exp(p*x) - 10
def g(p):
return quad_trap_p(integrand, -1, 1, 10000, p)
我刚刚写了一篇 simple benchmark 比较 Numba 和 Julia,以及一些讨论。
我想知道是否可以通过某种方式修复我的 Numba 代码,或者 Numba 是否确实不支持我尝试做的事情。
我们的想法是使用 JIT 编译的正交规则来计算这个函数。
g(p) = integrate exp(p*x) with respect to x
这是简单的正交函数:
@nb.njit
def quad_trap(f,a,b,N):
h = (b-a)/N
integral = h * ( f(a) + f(b) ) / 2
for k in range(N):
xk = (b-a) * k/N + a
integral = integral + h*f(xk)
return integral
我可以将 JIT 编译函数传递给此函数,例如:
@nb.njit(nb.float64(nb.float64))
def func(x):
return math.exp(x) - 10
这比纯 Python 快大约 10-20 倍,这非常好。
现在,我想做的是传递 x 的函数并由 p 参数化,类似于:
def g(p):
@nb.njit(nb.float64(nb.float64))
def integrand(x):
return math.exp(p*x) - 10
return quad_trap(integrand, -1, 1, 10000)
而且这样做似乎会破坏 Numba,即使与纯 Python 相比,Numba 也变得非常慢。
是我做错了什么,还是 Numba 确实不支持此功能? (我确实检查了 the documentation 但我不明白问题出在哪里)。谢谢!
TL;DR: Numba 似乎还不支持此功能。
that's about 10-20X faster than pure Python, which is pretty good.
Numba 函数 quad_trap
将在您第一次调用时编译。如果参数的类型发生变化,那么 Numba 将再次重新编译该函数。编译时间通常是不可忽略的(几毫秒到几秒)。为了避免这种情况,解决方案通常是指定参数的类型。但是,据我所知,由于功能的原因,这在这里是不可能的(至少没有记录)。话虽如此,因为您肯定会使用相同的函数对 quad_trap
函数进行基准测试,Numba 不应重新编译该函数,因为提供的参数类型不会改变。
doing seems to break down Numba, which becomes incredibly slow even when compared with pure Python.
在 Numba 的最新版本中,它可以在没有警告的情况下运行,但这是因为函数 integrand
被一遍又一遍地重新编译,因为 Numba 不知道它的代码是否改变了(或者那个 functions/operators 在此函数中递归调用)。在旧版本中,Numba 可能会抱怨函数 integrand
读取的参数 p
是从其父包含函数读取的。这称为闭包。
编译器通常不太支持闭包,因为处理它们要困难得多(它们需要从父函数的堆栈中读取变量)。一个反复出现的普遍问题是闭包可以逃脱其父函数的范围并在外部调用,从而导致未定义的行为(因为闭包将尝试读取已完成函数的失效堆栈)。
一个技巧是将 @nb.njit
装饰器从 integrand
移动到 g
但 Numba 拒绝编译 g
因为它不支持可能逃脱范围的闭包它的父函数(由于前面描述的问题)。请注意,闭包不会转义它在您的案例中定义的函数,但 Numba 无法证明这一点(因为 quad_trap
函数已经编译)并且不幸的是,当函数 quad_trap
时它也无法做到这一点是内联的(虽然理论上可以证明这是安全的)。事实上 documentations 声明:
Numba now supports inner functions as long as they are non-recursive and only called locally, but not passed as argument or returned as result. The use of closure variables (variables defined in outer scopes) within an inner function is also supported.
我认为 @generated_jit
装饰器可能有助于解决此类问题,但我未能成功使其适用于您的特定情况。它至少应该有助于在定义时编译 g
(如 integrand
)而不是在第一次调用期间。
一个解决方案就是不使用闭包:
@nb.njit
def quad_trap_p(f,a,b,N,p):
h = (b-a)/N
integral = h * ( f(a,p) + f(b,p) ) / 2
for k in range(N):
xk = (b-a) * k/N + a
integral = integral + h*f(xk,p)
return integral
@nb.njit(nb.float64(nb.float64, nb.float64))
def integrand(x, p):
return math.exp(p*x) - 10
def g(p):
return quad_trap_p(integrand, -1, 1, 10000, p)