为什么 numba 在我的代码中比 pure python 慢?
Why numba is slower than pure python in my code?
我是 python 的新手,我在玩 numba 并编写了一个 运行 比 numba 中的纯 python 慢的代码。在少量情况下,纯 python 比 numba 快 x4 倍左右,在大量情况下,它们 运行 几乎相同。是什么让我的代码 运行 在 numba 中变慢?
from numba import njit
@njit
def forr (q):
p=0
k=q
n=0
while k!=0:
n += 1
k=k//10
h=(abs(q-n*9)+q-n*9)//2
for j in range(q,h,-1):
s=0
k=j
while k!=0:
s += k%10
k=k//10
if s+j==q:
p=1
print('Yes')
break
if p==0:
print('No')
我认为你的 Numba 代码 运行 变慢的原因是因为接下来的事情:
- 可能您测量了第一个 运行 函数的时间,在第一次 Numba JIT 编译代码时可能需要几秒钟。要获得正确的时间测量,您需要首先单独调用 numba 函数,以便对其进行 JIT 预编译。
- 您可能没有提供足够大的输入(输入数字)因此您的函数只需要很少的时间并且 numba 函数有一些开销来启动。如果可能的话,在你的代码中你应该把相当长的算法放在 Numba 函数中,至少需要几十毫秒到 运行.
- 您可能正在测量几个 运行 秒,您必须在一个循环中测量数百个 运行 秒的函数才能获得更准确的结果。
- 您没有将
cache = True
选项放入 @njit
装饰器中,此选项将有助于在每个脚本 运行 中获取预编译代码,而不是从头开始编译。
- Print function call itself inside functions that take little time 可能会占用相当多的时间,因为控制台操作非常慢。最好 return 函数的结果并将它们打印在 Numba 函数之外。
考虑到以上所有内容,我实施了下一个代码来测量您的 Numba 代码,我只是添加了 cache = True
选项并注释掉了 print()
测量时间调用(不要破坏带有数百个的控制台测量时的话)。
下一段代码显示 Numba 变体在我的笔记本电脑上快 29x
倍。此外,下一段代码需要通过命令 pip install numba timerit
.
安装一次 pip 模块
import timerit, numba
timerit.Timerit._default_asciimode = True
def forr(q):
p=0
k=q
n=0
while k!=0:
n += 1
k=k//10
h=(abs(q-n*9)+q-n*9)//2
for j in range(q,h,-1):
s=0
k=j
while k!=0:
s += k%10
k=k//10
if s+j==q:
p=1
#print('Yes')
break
if p==0:
#print('No')
pass
nforr = numba.njit(cache = True)(forr)
nforr(2) # Heat-up, precompile numba
tb = None
for f in [forr, nforr]:
tim = timerit.Timerit(num = 99, verbose = 1)
for t in tim:
f(1 << 60)
if tb is None:
tb = tim.mean()
else:
print(f'speedup {round(tb / tim.mean(), 1)}x')
输出:
Timed best=1.029 ms, mean=1.040 +- 0.0 ms
Timed best=35.300 us, mean=35.673 +- 0.3 us
speedup 29.2x
我是 python 的新手,我在玩 numba 并编写了一个 运行 比 numba 中的纯 python 慢的代码。在少量情况下,纯 python 比 numba 快 x4 倍左右,在大量情况下,它们 运行 几乎相同。是什么让我的代码 运行 在 numba 中变慢?
from numba import njit
@njit
def forr (q):
p=0
k=q
n=0
while k!=0:
n += 1
k=k//10
h=(abs(q-n*9)+q-n*9)//2
for j in range(q,h,-1):
s=0
k=j
while k!=0:
s += k%10
k=k//10
if s+j==q:
p=1
print('Yes')
break
if p==0:
print('No')
我认为你的 Numba 代码 运行 变慢的原因是因为接下来的事情:
- 可能您测量了第一个 运行 函数的时间,在第一次 Numba JIT 编译代码时可能需要几秒钟。要获得正确的时间测量,您需要首先单独调用 numba 函数,以便对其进行 JIT 预编译。
- 您可能没有提供足够大的输入(输入数字)因此您的函数只需要很少的时间并且 numba 函数有一些开销来启动。如果可能的话,在你的代码中你应该把相当长的算法放在 Numba 函数中,至少需要几十毫秒到 运行.
- 您可能正在测量几个 运行 秒,您必须在一个循环中测量数百个 运行 秒的函数才能获得更准确的结果。
- 您没有将
cache = True
选项放入@njit
装饰器中,此选项将有助于在每个脚本 运行 中获取预编译代码,而不是从头开始编译。 - Print function call itself inside functions that take little time 可能会占用相当多的时间,因为控制台操作非常慢。最好 return 函数的结果并将它们打印在 Numba 函数之外。
考虑到以上所有内容,我实施了下一个代码来测量您的 Numba 代码,我只是添加了 cache = True
选项并注释掉了 print()
测量时间调用(不要破坏带有数百个的控制台测量时的话)。
下一段代码显示 Numba 变体在我的笔记本电脑上快 29x
倍。此外,下一段代码需要通过命令 pip install numba timerit
.
import timerit, numba
timerit.Timerit._default_asciimode = True
def forr(q):
p=0
k=q
n=0
while k!=0:
n += 1
k=k//10
h=(abs(q-n*9)+q-n*9)//2
for j in range(q,h,-1):
s=0
k=j
while k!=0:
s += k%10
k=k//10
if s+j==q:
p=1
#print('Yes')
break
if p==0:
#print('No')
pass
nforr = numba.njit(cache = True)(forr)
nforr(2) # Heat-up, precompile numba
tb = None
for f in [forr, nforr]:
tim = timerit.Timerit(num = 99, verbose = 1)
for t in tim:
f(1 << 60)
if tb is None:
tb = tim.mean()
else:
print(f'speedup {round(tb / tim.mean(), 1)}x')
输出:
Timed best=1.029 ms, mean=1.040 +- 0.0 ms
Timed best=35.300 us, mean=35.673 +- 0.3 us
speedup 29.2x