为什么 Numba 会扭曲 JIT 编译函数的时间?

Why does Numba skew the timings of a JIT-compiled function?

我正在尝试对一个 Python 函数进行基准测试,该函数使用 Numba 对 CPython 解释器进行列表操作。为了比较端到端时间,我使用了 Linux 时间实用程序。 time python3.10 list.py

据我所知,由于 JIT 编译,第一次调用会很昂贵,但这并不能解释为什么最大记录时间比 运行 整个脚本花费的总时间长。

# list.py
import numpy as np
from time import time, perf_counter 
from numba import njit

@njit
def listOperations():
  list = []
  for i in range(1000):
    list.append(i)
  
  list.sort(reverse=True)
  list.remove(420)
  list.reverse()

if __name__ == "__main__":
    repetitions = 1000
    timings = np.zeros(repetitions)

    for rep in range(repetitions):
        start = time()  # Similar results with perf_counter too.
        listOperations()
        timings[rep] = time() - start

    # Convert to milliseconds
    timings *= 10e3
    print("Mean {}ms, Median {}ms, Std. Dev {}ms, Min {}ms, Max {}ms".format(
            float('%.4f' % np.mean(timings)), 
            float('%.4f' % np.median(timings)), 
            float('%.4f' % np.std(timings)), 
            float('%.4f' % np.min(timings)), 
            float('%.4f' % np.max(timings)))
    )

对于 Numba,它显示最大值为 ~66.3 秒,而时间实用程序报告为~8 秒。完整结果如下。

'''
Numba --->
Mean 66.8154ms, Median 0.391ms, Std. Dev 2097.7752ms, Min 0.3219ms, Max 66371.1143ms

real  0m7.982s
user  0m8.248s
sys   0m0.100s

CPython3.10 --->
Mean 1.6395ms, Median 1.6284ms, Std. Dev 0.0708ms, Min 1.5759ms, Max 2.3198ms

real. 0m1.115s
user  0m1.468s
sys   0m0.080s 
'''

主要问题是编译时间包含在计时中。确实,Numba compiles the functions lazily。为防止这种情况,您必须 指定原型 或在外部执行第一个函数调用(这在基准测试中通常是一个很好的做法)。

您可以使用@njit('()')代替@njit。通过此修复,Numba 代码在我的机器上大约快两倍。

请注意,您的函数不会 return 任何内容,也不会读取参数中的任何内容,因此 JIT 可以将函数优化为 no-op。为了避免偏差,您当然需要添加一个参数,使用它并添加到 return 列表中。在我的机器上显然不是这种情况,但不同版本的 Numba 可能会这样做。

另请注意,Numba 列表通常不在 Numba 大放异彩的地方。列表通常很慢(使用和不使用 Numba)。 已知大小时使用数组更好

顺便说一下,list 是一个 built-in 函数。覆盖它可能会导致使用它的模块中出现偷偷摸摸的错误(经常),所以这不是一个好主意。我建议你换个名字。

此外,请注意结果中的标准偏差相当大,中值时间很好,最大时间非常大,表明时间不稳定,这种不稳定是由于一次缓慢的调用造成的。这样的结果通常表明基准测试存在缺陷或函数本身具有不稳定的行为(通常是由于错误或一次初始化完成)。