为什么 numba 在我的代码中比 pure python 慢?

Why numba is slower than pure python in my code?

我是 python 的新手,我在玩 numba 并编写了一个 运行 比 numba 中的纯 python 慢的代码。在少量情况下,纯 python 比 numba 快 x4 倍左右,在大量情况下,它们 运行 几乎相同。是什么让我的代码 运行 在 numba 中变慢?

from numba import njit
@njit
def forr (q):
    p=0
    k=q
    n=0
    while k!=0:
            n += 1
            k=k//10
    
    h=(abs(q-n*9)+q-n*9)//2 
    for j in range(q,h,-1):
        
        s=0
        k=j
        while k!=0:
            s += k%10
            k=k//10
        
        if s+j==q:
            p=1
            print('Yes')
            break
    if p==0:
        print('No')

我认为你的 Numba 代码 运行 变慢的原因是因为接下来的事情:

  1. 可能您测量了第一个 运行 函数的时间,在第一次 Numba JIT 编译代码时可能需要几秒钟。要获得正确的时间测量,您需要首先单独调用 numba 函数,以便对其进行 JIT 预编译。
  2. 您可能没有提供足够大的输入(输入数字)因此您的函数只需要很少的时间并且 numba 函数有一些开销来启动。如果可能的话,在你的代码中你应该把相当长的算法放在 Numba 函数中,至少需要几十毫秒到 运行.
  3. 您可能正在测量几个 运行 秒,您必须在一个循环中测量数百个 运行 秒的函数才能获得更准确的结果。
  4. 您没有将 cache = True 选项放入 @njit 装饰器中,此选项将有助于在每个脚本 运行 中获取预编译代码,而不是从头开始编译。
  5. Print function call itself inside functions that take little time 可能会占用相当多的时间,因为控制台操作非常慢。最好 return 函数的结果并将它们打印在 Numba 函数之外。

考虑到以上所有内容,我实施了下一个代码来测量您的 Numba 代码,我只是添加了 cache = True 选项并注释掉了 print() 测量时间调用(不要破坏带有数百个的控制台测量时的话)。

下一段代码显示 Numba 变体在我的笔记本电脑上快 29x 倍。此外,下一段代码需要通过命令 pip install numba timerit.

安装一次 pip 模块

Try it online!

import timerit, numba
timerit.Timerit._default_asciimode = True

def forr(q):
    p=0
    k=q
    n=0
    while k!=0:
            n += 1
            k=k//10
    
    h=(abs(q-n*9)+q-n*9)//2 
    for j in range(q,h,-1):
        
        s=0
        k=j
        while k!=0:
            s += k%10
            k=k//10
        
        if s+j==q:
            p=1
            #print('Yes')
            break
    if p==0:
        #print('No')
        pass
        
nforr = numba.njit(cache = True)(forr)
nforr(2) # Heat-up, precompile numba

tb = None
for f in [forr, nforr]:
    tim = timerit.Timerit(num = 99, verbose = 1)
    for t in tim:
        f(1 << 60)
    if tb is None:
        tb = tim.mean()
    else:
        print(f'speedup {round(tb / tim.mean(), 1)}x')

输出:

Timed best=1.029 ms, mean=1.040 +- 0.0 ms
Timed best=35.300 us, mean=35.673 +- 0.3 us
speedup 29.2x