最快的二重积分法
Fastest Double Integration Method
我正在使用 scipy 的二重积分 dblquad
并且我正在尝试提高速度。我检查了在线提出的解决方案,但无法使它们起作用。
为了缓解这个问题,我准备了下面的比较。我做错了什么或者我可以做些什么来提高速度?
from scipy import integrate
import timeit
from numba import njit, jit
def bb_pure(q, z, x_loc, y_loc, B, L):
def f(y, x):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
def bb_numbanjit(q, z, x_loc, y_loc, B, L):
@njit
def f(y, x):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
def bb_numbajit(q, z, x_loc, y_loc, B, L):
@jit
def f(y, x):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
####
starttime = timeit.default_timer()
for i in range(100):
bb_pure(200, 5, 0, 0, i, i*2)
print("Pure Function:", round(timeit.default_timer() - starttime,2))
####
starttime = timeit.default_timer()
for i in range(100):
bb_numbanjit(200, 5, 0, 0, i, i*2)
print("Numba njit:", round(timeit.default_timer() - starttime,2))
####
starttime = timeit.default_timer()
for i in range(100):
bb_numbajit(200, 5, 0, 0, i, i*2)
print("Numba jit:", round(timeit.default_timer() - starttime,2))
结果
Pure Function: 3.22
Numba njit: 8.14
Numba jit: 8.15
主要问题是您正在计算 Numba 函数的编译时间。实际上,当 bb_numbanjit
被调用时,@njit
装饰器告诉 Numba 声明一个 lazily-compiled 函数 ,它在 执行第一个调用 ,因此在 integrate.dblquad
。完全相同的行为适用于 bb_numbajit
。 Numba 实现速度较慢,因为编译时间与执行时间相比相当长。问题是 Numba 函数是 closures,它读取需要新编译的本地参数。解决这个问题的典型方法是向 Numba 函数添加新参数并编译一次。由于这里需要闭包,可以使用代理闭包。这是一个例子:
@njit
def f_numba(y, x, q, z, x_loc, y_loc, B, L):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
def bb_numbanjit(q, z, x_loc, y_loc, B, L):
def f_proxy(y, x):
return f_numba(y, x, q, z, x_loc, y_loc, B, L)
return integrate.dblquad(f_proxy, 0, B, lambda x: 0, lambda x: L)[0]
这比 bb_pure
解决方案 快两倍。
这个 Numba 解决方案并没有快多少的一个原因是 Python 函数调用很昂贵,尤其是当有很多参数时。另一个问题是某些参数似乎是常量,而 Numba 并不知道这一点,因为它们作为运行时参数而不是 compile-time 常量传递。您可以将全局变量中的常量移动到让 Numba 进一步优化代码(通过 pre-computing 常量 sub-expressions)。
另请注意,Numba 函数已在内部由代理函数包装。对于这种基本的数字操作,代理函数有点昂贵(它们进行一些类型检查和 pure-Python 对象到本机值的转换)。话虽这么说,由于关闭问题,这里没什么可做的。
我正在使用 scipy 的二重积分 dblquad
并且我正在尝试提高速度。我检查了在线提出的解决方案,但无法使它们起作用。
为了缓解这个问题,我准备了下面的比较。我做错了什么或者我可以做些什么来提高速度?
from scipy import integrate
import timeit
from numba import njit, jit
def bb_pure(q, z, x_loc, y_loc, B, L):
def f(y, x):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
def bb_numbanjit(q, z, x_loc, y_loc, B, L):
@njit
def f(y, x):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
def bb_numbajit(q, z, x_loc, y_loc, B, L):
@jit
def f(y, x):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
####
starttime = timeit.default_timer()
for i in range(100):
bb_pure(200, 5, 0, 0, i, i*2)
print("Pure Function:", round(timeit.default_timer() - starttime,2))
####
starttime = timeit.default_timer()
for i in range(100):
bb_numbanjit(200, 5, 0, 0, i, i*2)
print("Numba njit:", round(timeit.default_timer() - starttime,2))
####
starttime = timeit.default_timer()
for i in range(100):
bb_numbajit(200, 5, 0, 0, i, i*2)
print("Numba jit:", round(timeit.default_timer() - starttime,2))
结果
Pure Function: 3.22
Numba njit: 8.14
Numba jit: 8.15
主要问题是您正在计算 Numba 函数的编译时间。实际上,当 bb_numbanjit
被调用时,@njit
装饰器告诉 Numba 声明一个 lazily-compiled 函数 ,它在 执行第一个调用 ,因此在 integrate.dblquad
。完全相同的行为适用于 bb_numbajit
。 Numba 实现速度较慢,因为编译时间与执行时间相比相当长。问题是 Numba 函数是 closures,它读取需要新编译的本地参数。解决这个问题的典型方法是向 Numba 函数添加新参数并编译一次。由于这里需要闭包,可以使用代理闭包。这是一个例子:
@njit
def f_numba(y, x, q, z, x_loc, y_loc, B, L):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
def bb_numbanjit(q, z, x_loc, y_loc, B, L):
def f_proxy(y, x):
return f_numba(y, x, q, z, x_loc, y_loc, B, L)
return integrate.dblquad(f_proxy, 0, B, lambda x: 0, lambda x: L)[0]
这比 bb_pure
解决方案 快两倍。
这个 Numba 解决方案并没有快多少的一个原因是 Python 函数调用很昂贵,尤其是当有很多参数时。另一个问题是某些参数似乎是常量,而 Numba 并不知道这一点,因为它们作为运行时参数而不是 compile-time 常量传递。您可以将全局变量中的常量移动到让 Numba 进一步优化代码(通过 pre-computing 常量 sub-expressions)。
另请注意,Numba 函数已在内部由代理函数包装。对于这种基本的数字操作,代理函数有点昂贵(它们进行一些类型检查和 pure-Python 对象到本机值的转换)。话虽这么说,由于关闭问题,这里没什么可做的。