最快的二重积分法

Fastest Double Integration Method

我正在使用 scipy 的二重积分 dblquad 并且我正在尝试提高速度。我检查了在线提出的解决方案,但无法使它们起作用。 为了缓解这个问题,我准备了下面的比较。我做错了什么或者我可以做些什么来提高速度?

from scipy import integrate
import timeit
from numba import njit, jit

def bb_pure(q, z, x_loc, y_loc, B, L):

    def f(y, x):
        return (
            3
            * q
            * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
            )
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

def bb_numbanjit(q, z, x_loc, y_loc, B, L):
    
    @njit
    def f(y, x):
        return (
            3
            * q
            * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
            )
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

def bb_numbajit(q, z, x_loc, y_loc, B, L):
    
    @jit
    def f(y, x):
        return (
            3
            * q
            * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
            )
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

####

starttime = timeit.default_timer()
for i in range(100):
    bb_pure(200, 5, 0, 0, i, i*2)

print("Pure Function:", round(timeit.default_timer() - starttime,2))

####

starttime = timeit.default_timer()
for i in range(100):
    bb_numbanjit(200, 5, 0, 0, i, i*2)

print("Numba njit:", round(timeit.default_timer() - starttime,2))

####

starttime = timeit.default_timer()
for i in range(100):
    bb_numbajit(200, 5, 0, 0, i, i*2)

print("Numba jit:", round(timeit.default_timer() - starttime,2))

结果

Pure Function: 3.22
Numba njit: 8.14
Numba jit: 8.15

主要问题是您正在计算 Numba 函数的编译时间。实际上,当 bb_numbanjit 被调用时,@njit 装饰器告诉 Numba 声明一个 lazily-compiled 函数 ,它在 执行第一个调用 ,因此在 integrate.dblquad。完全相同的行为适用于 bb_numbajit。 Numba 实现速度较慢,因为编译时间与执行时间相比相当长。问题是 Numba 函数是 closures,它读取需要新编译的本地参数。解决这个问题的典型方法是向 Numba 函数添​​加新参数并编译一次。由于这里需要闭包,可以使用代理闭包。这是一个例子:

@njit
def f_numba(y, x, q, z, x_loc, y_loc, B, L):
    return (
        3
        * q
        * z ** 3
        / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
        )
    )

def bb_numbanjit(q, z, x_loc, y_loc, B, L):
    def f_proxy(y, x):
        return f_numba(y, x, q, z, x_loc, y_loc, B, L)

    return integrate.dblquad(f_proxy, 0, B, lambda x: 0, lambda x: L)[0]

这比 bb_pure 解决方案 快两倍

这个 Numba 解决方案并没有快多少的一个原因是 Python 函数调用很昂贵,尤其是当有很多参数时。另一个问题是某些参数似乎是常量,而 Numba 并不知道这一点,因为它们作为运行时参数而不是 compile-time 常量传递。您可以将全局变量中的常量移动到让 Numba 进一步优化代码(通过 pre-computing 常量 sub-expressions)。

另请注意,Numba 函数已在内部由代理函数包装。对于这种基本的数字操作,代理函数有点昂贵(它们进行一些类型检查和 pure-Python 对象到本机值的转换)。话虽这么说,由于关闭问题,这里没什么可做的。