使用 numba 优化 while 循环以实现容错
Optimizing while cycle with numba for error tolerance
我在使用numba做优化的时候有疑问。我正在编写一个定点迭代来计算某个名为 gamma 的数组的值,它满足方程 f(gamma)=gamma。我正在尝试使用 python 包 Numba 优化此功能。好像是这样。
@jit
def fixed_point(gamma_guess):
for i in range(17):
gamma_guess=f(gamma_guess)
return gamma_guess
Numba 能够很好地优化这个功能,因为它知道要执行多少次操作,17 次,而且速度很快。但是我需要控制我想要的伽玛的误差容忍度,我的意思是,通过定点迭代获得的伽玛和下一个伽玛的差异应该小于某个数字epsilon = 0.01,然后我尝试了
@jit
def fixed_point(gamma_guess):
err=1000
gamma_old=gamma_guess.copy()
while(error>0.01):
gamma_guess=f(gamma_guess)
err=np.max(abs(gamma_guess-gamma_old))
gamma_old=gamma_guess.copy()
return gamma_guess
它也能工作并计算出想要的结果,但不如上次实现快,它慢得多。我认为这是因为 Numba 无法很好地优化 while 循环,因为我们不知道它什么时候会停止。有没有一种方法可以优化这个并且 运行 和上次实施一样快?
编辑:
这是我使用的 f
from scipy import fftpack as sp
S=0.01
Amu=0.7
@jit
def f(gammaa,z,zal,kappa):
ka=sp.diff(kappa)
gamma0=gammaa
for i in range(N):
suma=0
for j in range(N):
if (abs(j-i))%2 ==1:
if((z[i]-z[j])==0):
suma+=(gamma0[j]/(z[i]-z[j]))
gamma0[i]=2.0*Amu*np.real(-(zal[i]/z[i])+zal[i]*(1.0/(2*np.pi*1j))*suma*2*h)+S*ka[i]
return gamma0
我总是使用 np.ones(2048)*0.5
作为初始猜测,我传递给函数的其他参数是 z=np.cos(alphas)+1j*(np.sin(alphas)+0.1)
、 zal=-np.sin(alphas)+1j*np.cos(alphas)
、 kappa=np.ones(2048)
和 alphas=np.arange(0,2*np.pi,2*np.pi/2048)
我做了一个小测试脚本,看看我是否可以重现你的错误:
import numba as nb
from IPython import get_ipython
ipython = get_ipython()
@nb.jit(nopython=True)
def f(x):
return (x+1)/x
def fixed_point_for(x):
for _ in range(17):
x = f(x)
return x
@nb.jit(nopython=True)
def fixed_point_for_nb(x):
for _ in range(17):
x = f(x)
return x
def fixed_point_while(x):
error=1
x_old = x
while error>0.01:
x = f(x)
error = abs(x_old-x)
x_old = x
return x
@nb.jit(nopython=True)
def fixed_point_while_nb(x):
error=1
x_old = x
while error>0.01:
x = f(x)
error = abs(x_old-x)
x_old = x
return x
print("for loop without numba:")
ipython.magic("%timeit fixed_point_for(10)")
print("for loop with numba:")
ipython.magic("%timeit fixed_point_for_nb(10)")
print("while loop without numba:")
ipython.magic("%timeit fixed_point_while(10)")
print("for loop with numba:")
ipython.magic("%timeit fixed_point_while_nb(10)")
因为我不知道你的情况 f
我只是使用了我能想到的最简单的稳定功能。然后我 运行 使用和不使用 numba
进行测试,两次都使用 for
和 while
循环。我机器上的结果是:
for loop without numba:
3.35 µs ± 8.72 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
for loop with numba:
282 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
while loop without numba:
1.86 µs ± 7.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
for loop with numba:
214 ns ± 1.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
出现以下想法:
- 不可能,你的函数不可优化,因为你的
for
循环很快(至少你是这么说的;你在没有 numba
的情况下测试过吗?)。
- 可能是您的函数需要更多循环才能收敛,就像您想象的那样
- 我们正在使用不同的软件版本。我的版本是:
numba 0.49.0
numpy 1.18.3
python 3.8.2
我在使用numba做优化的时候有疑问。我正在编写一个定点迭代来计算某个名为 gamma 的数组的值,它满足方程 f(gamma)=gamma。我正在尝试使用 python 包 Numba 优化此功能。好像是这样。
@jit
def fixed_point(gamma_guess):
for i in range(17):
gamma_guess=f(gamma_guess)
return gamma_guess
Numba 能够很好地优化这个功能,因为它知道要执行多少次操作,17 次,而且速度很快。但是我需要控制我想要的伽玛的误差容忍度,我的意思是,通过定点迭代获得的伽玛和下一个伽玛的差异应该小于某个数字epsilon = 0.01,然后我尝试了
@jit
def fixed_point(gamma_guess):
err=1000
gamma_old=gamma_guess.copy()
while(error>0.01):
gamma_guess=f(gamma_guess)
err=np.max(abs(gamma_guess-gamma_old))
gamma_old=gamma_guess.copy()
return gamma_guess
它也能工作并计算出想要的结果,但不如上次实现快,它慢得多。我认为这是因为 Numba 无法很好地优化 while 循环,因为我们不知道它什么时候会停止。有没有一种方法可以优化这个并且 运行 和上次实施一样快?
编辑:
这是我使用的 f
from scipy import fftpack as sp
S=0.01
Amu=0.7
@jit
def f(gammaa,z,zal,kappa):
ka=sp.diff(kappa)
gamma0=gammaa
for i in range(N):
suma=0
for j in range(N):
if (abs(j-i))%2 ==1:
if((z[i]-z[j])==0):
suma+=(gamma0[j]/(z[i]-z[j]))
gamma0[i]=2.0*Amu*np.real(-(zal[i]/z[i])+zal[i]*(1.0/(2*np.pi*1j))*suma*2*h)+S*ka[i]
return gamma0
我总是使用 np.ones(2048)*0.5
作为初始猜测,我传递给函数的其他参数是 z=np.cos(alphas)+1j*(np.sin(alphas)+0.1)
、 zal=-np.sin(alphas)+1j*np.cos(alphas)
、 kappa=np.ones(2048)
和 alphas=np.arange(0,2*np.pi,2*np.pi/2048)
我做了一个小测试脚本,看看我是否可以重现你的错误:
import numba as nb
from IPython import get_ipython
ipython = get_ipython()
@nb.jit(nopython=True)
def f(x):
return (x+1)/x
def fixed_point_for(x):
for _ in range(17):
x = f(x)
return x
@nb.jit(nopython=True)
def fixed_point_for_nb(x):
for _ in range(17):
x = f(x)
return x
def fixed_point_while(x):
error=1
x_old = x
while error>0.01:
x = f(x)
error = abs(x_old-x)
x_old = x
return x
@nb.jit(nopython=True)
def fixed_point_while_nb(x):
error=1
x_old = x
while error>0.01:
x = f(x)
error = abs(x_old-x)
x_old = x
return x
print("for loop without numba:")
ipython.magic("%timeit fixed_point_for(10)")
print("for loop with numba:")
ipython.magic("%timeit fixed_point_for_nb(10)")
print("while loop without numba:")
ipython.magic("%timeit fixed_point_while(10)")
print("for loop with numba:")
ipython.magic("%timeit fixed_point_while_nb(10)")
因为我不知道你的情况 f
我只是使用了我能想到的最简单的稳定功能。然后我 运行 使用和不使用 numba
进行测试,两次都使用 for
和 while
循环。我机器上的结果是:
for loop without numba:
3.35 µs ± 8.72 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
for loop with numba:
282 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
while loop without numba:
1.86 µs ± 7.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
for loop with numba:
214 ns ± 1.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
出现以下想法:
- 不可能,你的函数不可优化,因为你的
for
循环很快(至少你是这么说的;你在没有numba
的情况下测试过吗?)。 - 可能是您的函数需要更多循环才能收敛,就像您想象的那样
- 我们正在使用不同的软件版本。我的版本是:
numba 0.49.0
numpy 1.18.3
python 3.8.2