使用 numba 优化 while 循环以实现容错

Optimizing while cycle with numba for error tolerance

我在使用numba做优化的时候有疑问。我正在编写一个定点迭代来计算某个名为 gamma 的数组的值,它满足方程 f(gamma)=gamma。我正在尝试使用 python 包 Numba 优化此功能。好像是这样。

@jit
def fixed_point(gamma_guess):
    for i in range(17):
        gamma_guess=f(gamma_guess)
    return gamma_guess

Numba 能够很好地优化这个功能,因为它知道要执行多少次操作,17 次,而且速度很快。但是我需要控制我想要的伽玛的误差容忍度,我的意思是,通过定点迭代获得的伽玛和下一个伽玛的差异应该小于某个数字epsilon = 0.01,然后我尝试了

@jit
def fixed_point(gamma_guess):
    err=1000
    gamma_old=gamma_guess.copy()
    while(error>0.01):
        gamma_guess=f(gamma_guess)
        err=np.max(abs(gamma_guess-gamma_old))
        gamma_old=gamma_guess.copy()
    return gamma_guess

它也能工作并计算出想要的结果,但不如上次实现快,它慢得多。我认为这是因为 Numba 无法很好地优化 while 循环,因为我们不知道它什么时候会停止。有没有一种方法可以优化这个并且 运行 和上次实施一样快?

编辑:

这是我使用的 f

from scipy import fftpack as sp
S=0.01
Amu=0.7
@jit 
def f(gammaa,z,zal,kappa):
    ka=sp.diff(kappa)
    gamma0=gammaa
    for i in range(N):
        suma=0
        for j in range(N):
            if (abs(j-i))%2 ==1:
                if((z[i]-z[j])==0):
                    suma+=(gamma0[j]/(z[i]-z[j]))   
        gamma0[i]=2.0*Amu*np.real(-(zal[i]/z[i])+zal[i]*(1.0/(2*np.pi*1j))*suma*2*h)+S*ka[i]
    return  gamma0

我总是使用 np.ones(2048)*0.5 作为初始猜测,我传递给函数的其他参数是 z=np.cos(alphas)+1j*(np.sin(alphas)+0.1)zal=-np.sin(alphas)+1j*np.cos(alphas)kappa=np.ones(2048)alphas=np.arange(0,2*np.pi,2*np.pi/2048)

我做了一个小测试脚本,看看我是否可以重现你的错误:

import numba as nb

from IPython import get_ipython
ipython = get_ipython()

@nb.jit(nopython=True)
def f(x):
    return (x+1)/x


def fixed_point_for(x):
    for _ in range(17):
        x = f(x)
    return x

@nb.jit(nopython=True)
def fixed_point_for_nb(x):
    for _ in range(17):
        x = f(x)
    return x

def fixed_point_while(x):
    error=1
    x_old = x
    while error>0.01:
        x = f(x)
        error = abs(x_old-x)
        x_old = x
    return x

@nb.jit(nopython=True)
def fixed_point_while_nb(x):
    error=1
    x_old = x
    while error>0.01:
        x = f(x)
        error = abs(x_old-x)
        x_old = x
    return x

print("for loop without numba:")
ipython.magic("%timeit fixed_point_for(10)")

print("for loop with numba:")
ipython.magic("%timeit fixed_point_for_nb(10)")

print("while loop without numba:")
ipython.magic("%timeit fixed_point_while(10)")

print("for loop with numba:")
ipython.magic("%timeit fixed_point_while_nb(10)")

因为我不知道你的情况 f 我只是使用了我能想到的最简单的稳定功能。然后我 运行 使用和不使用 numba 进行测试,两次都使用 forwhile 循环。我机器上的结果是:

for loop without numba:
3.35 µs ± 8.72 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
for loop with numba:
282 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
while loop without numba:
1.86 µs ± 7.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
for loop with numba:
214 ns ± 1.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

出现以下想法:

  • 不可能,你的函数不可优化,因为你的 for 循环很快(至少你是这么说的;你在没有 numba 的情况下测试过吗?)。
  • 可能是您的函数需要更多循环才能收敛,就像您想象的那样
  • 我们正在使用不同的软件版本。我的版本是:
    • numba 0.49.0
    • numpy 1.18.3
    • python 3.8.2