使用 numba JIT 加速函数时遇到问题
Trouble with speeding up functions with numba JIT
我是 numba 的 jit 新手。对于个人项目,我需要加速类似于下面显示的功能,尽管出于编写独立示例的目的而有所不同。
import numpy as np
from numba import jit, autojit, double, float64, float32, void
def f(n):
k=0.
for i in range(n):
for j in range(n):
k+= i+j
def f_with_return(n):
k=0.
for i in range(n):
for j in range(n):
k+= i+j
return k
def f_with_arange(n):
k=0.
for i in np.arange(n):
for j in np.arange(n):
k+= i+j
def f_with_arange_and_return(n):
k=0.
for i in np.arange(n):
for j in np.arange(n):
k+= i+j
#jit decorators
jit_f = jit(void(int32))(f)
jit_f_with_return = jit(int32(int32))(f_with_return)
jit_f_with_arange = jit(void(double))(f_with_arange)
jit_f_with_arange_and_return = jit(double(double))(f_with_arange_and_return)
和基准:
%timeit f(1000)
%timeit jit_f(1000)
10 个循环,3 个循环中的最佳:每个循环 73.9 毫秒/1000000 个循环,3 个循环中的最佳:每个循环 212 ns
%timeit f_with_return(1000)
%timeit jit_f_with_return(1000)
10 个循环,3 个循环中的最佳:每个循环 74.9 毫秒/1000000 个循环,3 个循环中的最佳:每个循环 220 ns
这两个没看懂:
%timeit f_with_arange(1000.0)
%timeit jit_f_with_arange(1000.0)
10 次循环,3 次最佳:每次循环 175 毫秒/1 次循环,3 次最佳:每次循环 167 毫秒
%timeit f_with_arange_with_return(1000.0)
%timeit jit_f_with_arange_with_return(1000.0)
10 次循环,3 次最佳:每次循环 174 毫秒/1 次循环,3 次最佳:每次循环 172 毫秒
我想我没有为 jit 函数提供正确的输出和输入类型?仅仅因为 for 循环现在在 numpy.arange 之上 运行 而不再是一个简单的范围,我无法让 jit 让它更快。这里有什么问题?
简单地说,numba 不知道如何将 np.arange
转换为低级本机循环,因此它默认为对象层,该对象层要慢得多,通常与纯 python 的速度相同.
一个不错的技巧是将 nopython=True
关键字参数传递给 jit
以查看它是否可以在不诉诸对象模式的情况下编译所有内容:
import numpy as np
import numba as nb
def f_with_return(n):
k=0.
for i in range(n):
for j in range(n):
k+= i+j
return k
jit_f_with_return = nb.jit()(f_with_return)
jit_f_with_return_nopython = nb.jit(nopython=True)(f_with_return)
%timeit f_with_return(1000)
%timeit jit_f_with_return(1000)
%timeit jit_f_with_return_nopython(1000)
最后两个在我的机器上速度相同,比未编译的代码快得多。您有疑问的两个示例将引发 nopython=True
错误,因为此时无法编译 np.arange
。
查看以下内容了解更多详情:
http://numba.pydata.org/numba-doc/0.17.0/user/troubleshoot.html#the-compiled-code-is-too-slow
以及在 nopython
模式下支持和不支持的指示的受支持 numpy 功能列表:
http://numba.pydata.org/numba-doc/0.17.0/reference/numpysupported.html
我是 numba 的 jit 新手。对于个人项目,我需要加速类似于下面显示的功能,尽管出于编写独立示例的目的而有所不同。
import numpy as np
from numba import jit, autojit, double, float64, float32, void
def f(n):
k=0.
for i in range(n):
for j in range(n):
k+= i+j
def f_with_return(n):
k=0.
for i in range(n):
for j in range(n):
k+= i+j
return k
def f_with_arange(n):
k=0.
for i in np.arange(n):
for j in np.arange(n):
k+= i+j
def f_with_arange_and_return(n):
k=0.
for i in np.arange(n):
for j in np.arange(n):
k+= i+j
#jit decorators
jit_f = jit(void(int32))(f)
jit_f_with_return = jit(int32(int32))(f_with_return)
jit_f_with_arange = jit(void(double))(f_with_arange)
jit_f_with_arange_and_return = jit(double(double))(f_with_arange_and_return)
和基准:
%timeit f(1000)
%timeit jit_f(1000)
10 个循环,3 个循环中的最佳:每个循环 73.9 毫秒/1000000 个循环,3 个循环中的最佳:每个循环 212 ns
%timeit f_with_return(1000)
%timeit jit_f_with_return(1000)
10 个循环,3 个循环中的最佳:每个循环 74.9 毫秒/1000000 个循环,3 个循环中的最佳:每个循环 220 ns
这两个没看懂:
%timeit f_with_arange(1000.0)
%timeit jit_f_with_arange(1000.0)
10 次循环,3 次最佳:每次循环 175 毫秒/1 次循环,3 次最佳:每次循环 167 毫秒
%timeit f_with_arange_with_return(1000.0)
%timeit jit_f_with_arange_with_return(1000.0)
10 次循环,3 次最佳:每次循环 174 毫秒/1 次循环,3 次最佳:每次循环 172 毫秒
我想我没有为 jit 函数提供正确的输出和输入类型?仅仅因为 for 循环现在在 numpy.arange 之上 运行 而不再是一个简单的范围,我无法让 jit 让它更快。这里有什么问题?
简单地说,numba 不知道如何将 np.arange
转换为低级本机循环,因此它默认为对象层,该对象层要慢得多,通常与纯 python 的速度相同.
一个不错的技巧是将 nopython=True
关键字参数传递给 jit
以查看它是否可以在不诉诸对象模式的情况下编译所有内容:
import numpy as np
import numba as nb
def f_with_return(n):
k=0.
for i in range(n):
for j in range(n):
k+= i+j
return k
jit_f_with_return = nb.jit()(f_with_return)
jit_f_with_return_nopython = nb.jit(nopython=True)(f_with_return)
%timeit f_with_return(1000)
%timeit jit_f_with_return(1000)
%timeit jit_f_with_return_nopython(1000)
最后两个在我的机器上速度相同,比未编译的代码快得多。您有疑问的两个示例将引发 nopython=True
错误,因为此时无法编译 np.arange
。
查看以下内容了解更多详情:
http://numba.pydata.org/numba-doc/0.17.0/user/troubleshoot.html#the-compiled-code-is-too-slow
以及在 nopython
模式下支持和不支持的指示的受支持 numpy 功能列表:
http://numba.pydata.org/numba-doc/0.17.0/reference/numpysupported.html