为什么 Numba 的 "Eager compilation" 会减慢执行速度

Why Numba's "Eager compilation" slows down the execution

附上一个最小的例子:

from numba import jit
import numba as nb
import numpy as np

@jit(nb.float64[:, :](nb.int32[:, :])) 
def go_fast(a): 
    trace = 0.0
    for i in range(a.shape[0]):  
        trace += np.tanh(a[i, i]) 
    return a + trace          

@jit 
def go_fast2(a): 
    trace = 0.0
    for i in range(a.shape[0]):  
        trace += np.tanh(a[i, i]) 
    return a + trace 

运行 在 Jupyter 中:

x = np.arange(10000).reshape(100, 100)
%timeit go_fast(x)
%timeit go_fast2(x)

导致

每个循环 5.65 µs ± 27.1 ns(7 次运行的平均值 ± 标准偏差,每次 100000 次循环)

每个循环 3.8 µs ± 46.6 ns(7 次运行的平均值 ± 标准偏差,每次 100000 次循环)

为什么急切的编译会导致执行速度变慢?

知道内存访问是连续的可以简化优化器的生命周期(这里是 对于 Cython,但类似的情况适用于 numba,即使 clang 通常比 gcc 更聪明)。

你的例子好像是这样的:

  1. 如果没有“急切的编译”,numba 将检测到数据是 C 连续的并利用它,例如用于矢量化。
  2. 使用 eager compilation,您不提供此信息,因此优化器必须考虑到内存访问可能是非连续的,并且会创建一个性能低于第一个版本的 jit 代码。

因此,您应该提供更准确的签名:

@jit(nb.float64[:, ::1](nb.int32[:, ::1])) 
def go_fast3(a): 
    trace = 0.0
    for i in range(a.shape[0]):  
        trace += np.tanh(a[i, i]) 
    return a + trace

[:,::1] tells numba,数据将是 C 连续的,一旦利用此信息:

x = np.arange(10000).astype(np.int32).reshape(100, 100)
%timeit go_fast(x)     # 15.6 µs ± 241 ns per loop
%timeit go_fast2(x)    # 8.2 µs ± 90.7 ns per loop
%timeit go_fast3(x)    # 8.2 µs ± 49.6 ns per loop

急切编译版本没有区别。