如何正确使用 Numba 加速?
how to speed up correctly using Numba?
我目前正在为我的 python 函数加速。
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
def distance(u,v,lon1,lat1):
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlon = lon2 - lon1
dlat = lat2 - lat1
return dlon, dlat
如您所见,这是基于 numpy 的简单代码。
我看了网上的大部分文章,他们说的只是把@numba.jit作为函数前面的装饰器,然后我可以使用Numba来加速我的代码。
这是我做过的测试。
u = np.random.randn(10000)
v = np.random.randn(10000)
lon1 = np.random.uniform(-99,-96,10000)
lat1 = np.random.uniform( 23, 25,10000)
print(u)
%%timeit
for i in range(10000):
distance(u,v,lon1,lat1)
每个循环 5.61 秒 ± 58.7 毫秒(7 次运行的平均值 ± 标准差,每次 1 个循环)
添加 Numba 装饰器
@numba.njit()
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
@numba.njit()
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
@numba.njit()
def distance(u, v, lon1, lat1, R=6.371*1e6):
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlat = lat2 - lat1
dlon = lon2 - lon1
return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
a,b = distance(u,v,lon1,lat1)
每个循环 7.76 秒 ± 64.9 毫秒(7 次运行的平均值 ± 标准偏差,每次 1 个循环)
正如您在上面看到的,我的 Numba 案例的计算速度比我的纯 python 案例慢。谁能帮我解决这个问题?
ps: numba 版本
llvmlite 0.32.0rc1
numba 0.49.0rc2
------ 关于宏观经济学家答案的计算测试。 ------
根据他的回答,Numba现在已经足够聪明了,如果我们希望代码被Numba装饰,最好使用普通的"Fortran"/"C"类型的样式.下面是我正在考虑的不同方法之间的计算时间比较。
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
def distance(u,v,lon1,lat1):
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlon = lon2 - lon1
dlat = lat2 - lat1
return dlon, dlat
%%timeit
for i in range(10000):
distance(u,v,lon1,lat1)
每个循环 54 秒 ± 485 毫秒(7 次运行的平均值 ± 标准差,每次 1 个循环)
@numba.jit(nogil=True)
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
@numba.jit(nogil=True)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
def distance(u, v, lon1, lat1, R=6.371*1e6):
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlat = lat2 - lat1
dlon = lon2 - lon1
return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
a,b = distance(u,v,lon1,lat1)
每个循环 1 分钟 21 秒 ± 815 毫秒(7 次运行的平均值 ± 标准偏差,每次 1 个循环)
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
@numba.njit(nogil=True)
def distance(u, v, lon1, lat1, R=6.371*1e6):
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlat = d_lat(lat2 - lat1)
dlon = d_lon(lat1,lat2,lon2 - lon1)
return dlon, dlat
%%timeit
for i in range(10000):
a,b = distance(u,v,lon1,lat1)
1min 2s ± 239 ms per loop (mean ± std.dev. of 7 runs, 1 loop each)
@numba.njit()
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
@numba.njit()
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
@numba.njit()
def distance(u, v, lon1, lat1):
lon2 = np.empty_like(lon1)
lat2 = np.empty_like(lat1)
dlon = np.empty_like(lon1)
dlat = np.empty_like(lat1)
for i in range(len(v)):
vi = v[i]
if vi > 0:
lat2[i] = lat1[i]+1
dlat[i] = 1
elif vi < 0:
lat2[i] = lat1[i]-1
dlat[i] = -1
else:
lat2[i] = lat1[i]
dlat[i] = 0
for i in range(len(u)):
ui = u[i]
if ui > 0:
lon2[i] = lon1[i]+1
dlon[i] = 1
elif ui < 0:
lon2[i] = lon1[i]-1
dlon[i] = -1
else:
lon2[i] = lon1[i]
dlon[i] = 0
return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
distance(u,v,lon1,lat1)
每个循环 35.9 s ± 537 ms(7 次运行的平均值 ± 标准偏差,每次 1 个循环)
有几个问题跳出来了。
首先,您在 distance
函数中的计算不必要地复杂,并且以一种可能不适合 Numba 编译器的风格编写(有很多花哨的索引,例如 lat2[v>0]
)。尽管 Numba 变得越来越聪明,但我发现以简单的、面向循环的方式编写代码仍然有很高的 return。
其次,Numba 可以通过可选参数稍微减慢速度。我发现这主要适用于 distance
函数中的可选 R
。
解决这两个问题 - 特别是,用更简单的循环替换你的矢量化代码,最大限度地减少操作 - 我们得到
形式的代码
@numba.njit()
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
@numba.njit()
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
@numba.njit()
def distance(u, v, lon1, lat1):
lon2 = np.empty_like(lon1)
lat2 = np.empty_like(lat1)
dlon = np.empty_like(lon1)
dlat = np.empty_like(lat1)
for i in range(len(v)):
vi = v[i]
if vi > 0:
lat2[i] = lat1[i]+1
dlat[i] = 1
elif vi < 0:
lat2[i] = lat1[i]-1
dlat[i] = -1
else:
lat2[i] = lat1[i]
dlat[i] = 0
for i in range(len(u)):
ui = u[i]
if ui > 0:
lon2[i] = lon1[i]+1
dlon[i] = 1
elif ui < 0:
lon2[i] = lon1[i]-1
dlon[i] = -1
else:
lon2[i] = lon1[i]
dlon[i] = 0
return d_lon(lat1,lat2,dlon), d_lat(dlat)
在我的(较慢的)系统上,这将编译初始成本后的时间从大约 7 秒减少到大约 4 秒。到那时,我认为成本主要由所有函数的原始成本决定 np.sin
、np.cos
、np.exp
等
我目前正在为我的 python 函数加速。
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
def distance(u,v,lon1,lat1):
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlon = lon2 - lon1
dlat = lat2 - lat1
return dlon, dlat
如您所见,这是基于 numpy 的简单代码。 我看了网上的大部分文章,他们说的只是把@numba.jit作为函数前面的装饰器,然后我可以使用Numba来加速我的代码。
这是我做过的测试。
u = np.random.randn(10000)
v = np.random.randn(10000)
lon1 = np.random.uniform(-99,-96,10000)
lat1 = np.random.uniform( 23, 25,10000)
print(u)
%%timeit
for i in range(10000):
distance(u,v,lon1,lat1)
每个循环 5.61 秒 ± 58.7 毫秒(7 次运行的平均值 ± 标准差,每次 1 个循环)
添加 Numba 装饰器
@numba.njit()
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
@numba.njit()
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
@numba.njit()
def distance(u, v, lon1, lat1, R=6.371*1e6):
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlat = lat2 - lat1
dlon = lon2 - lon1
return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
a,b = distance(u,v,lon1,lat1)
每个循环 7.76 秒 ± 64.9 毫秒(7 次运行的平均值 ± 标准偏差,每次 1 个循环)
正如您在上面看到的,我的 Numba 案例的计算速度比我的纯 python 案例慢。谁能帮我解决这个问题?
ps: numba 版本
llvmlite 0.32.0rc1
numba 0.49.0rc2
------ 关于宏观经济学家答案的计算测试。 ------
根据他的回答,Numba现在已经足够聪明了,如果我们希望代码被Numba装饰,最好使用普通的"Fortran"/"C"类型的样式.下面是我正在考虑的不同方法之间的计算时间比较。
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
def distance(u,v,lon1,lat1):
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlon = lon2 - lon1
dlat = lat2 - lat1
return dlon, dlat
%%timeit
for i in range(10000):
distance(u,v,lon1,lat1)
每个循环 54 秒 ± 485 毫秒(7 次运行的平均值 ± 标准差,每次 1 个循环)
@numba.jit(nogil=True)
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
@numba.jit(nogil=True)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
def distance(u, v, lon1, lat1, R=6.371*1e6):
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlat = lat2 - lat1
dlon = lon2 - lon1
return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
a,b = distance(u,v,lon1,lat1)
每个循环 1 分钟 21 秒 ± 815 毫秒(7 次运行的平均值 ± 标准偏差,每次 1 个循环)
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
@numba.njit(nogil=True)
def distance(u, v, lon1, lat1, R=6.371*1e6):
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
lat2, lon2 = lat1.copy(), lon1.copy()
lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
dlat = d_lat(lat2 - lat1)
dlon = d_lon(lat1,lat2,lon2 - lon1)
return dlon, dlat
%%timeit
for i in range(10000):
a,b = distance(u,v,lon1,lat1)
1min 2s ± 239 ms per loop (mean ± std.dev. of 7 runs, 1 loop each)
@numba.njit()
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
@numba.njit()
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
@numba.njit()
def distance(u, v, lon1, lat1):
lon2 = np.empty_like(lon1)
lat2 = np.empty_like(lat1)
dlon = np.empty_like(lon1)
dlat = np.empty_like(lat1)
for i in range(len(v)):
vi = v[i]
if vi > 0:
lat2[i] = lat1[i]+1
dlat[i] = 1
elif vi < 0:
lat2[i] = lat1[i]-1
dlat[i] = -1
else:
lat2[i] = lat1[i]
dlat[i] = 0
for i in range(len(u)):
ui = u[i]
if ui > 0:
lon2[i] = lon1[i]+1
dlon[i] = 1
elif ui < 0:
lon2[i] = lon1[i]-1
dlon[i] = -1
else:
lon2[i] = lon1[i]
dlon[i] = 0
return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
distance(u,v,lon1,lat1)
每个循环 35.9 s ± 537 ms(7 次运行的平均值 ± 标准偏差,每次 1 个循环)
有几个问题跳出来了。
首先,您在 distance
函数中的计算不必要地复杂,并且以一种可能不适合 Numba 编译器的风格编写(有很多花哨的索引,例如 lat2[v>0]
)。尽管 Numba 变得越来越聪明,但我发现以简单的、面向循环的方式编写代码仍然有很高的 return。
其次,Numba 可以通过可选参数稍微减慢速度。我发现这主要适用于 distance
函数中的可选 R
。
解决这两个问题 - 特别是,用更简单的循环替换你的矢量化代码,最大限度地减少操作 - 我们得到
形式的代码@numba.njit()
def d_lat(dlat,R=6.371*1e6):
return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)
@numba.njit()
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *
np.cos(np.deg2rad(lat2)) *
np.sin(np.deg2rad(dlon)/2)**2)
@numba.njit()
def distance(u, v, lon1, lat1):
lon2 = np.empty_like(lon1)
lat2 = np.empty_like(lat1)
dlon = np.empty_like(lon1)
dlat = np.empty_like(lat1)
for i in range(len(v)):
vi = v[i]
if vi > 0:
lat2[i] = lat1[i]+1
dlat[i] = 1
elif vi < 0:
lat2[i] = lat1[i]-1
dlat[i] = -1
else:
lat2[i] = lat1[i]
dlat[i] = 0
for i in range(len(u)):
ui = u[i]
if ui > 0:
lon2[i] = lon1[i]+1
dlon[i] = 1
elif ui < 0:
lon2[i] = lon1[i]-1
dlon[i] = -1
else:
lon2[i] = lon1[i]
dlon[i] = 0
return d_lon(lat1,lat2,dlon), d_lat(dlat)
在我的(较慢的)系统上,这将编译初始成本后的时间从大约 7 秒减少到大约 4 秒。到那时,我认为成本主要由所有函数的原始成本决定 np.sin
、np.cos
、np.exp
等