为什么 Numba 的 "Eager compilation" 会减慢执行速度
Why Numba's "Eager compilation" slows down the execution
附上一个最小的例子:
from numba import jit
import numba as nb
import numpy as np
@jit(nb.float64[:, :](nb.int32[:, :]))
def go_fast(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace
@jit
def go_fast2(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace
运行 在 Jupyter 中:
x = np.arange(10000).reshape(100, 100)
%timeit go_fast(x)
%timeit go_fast2(x)
导致
每个循环 5.65 µs ± 27.1 ns(7 次运行的平均值 ± 标准偏差,每次 100000 次循环)
每个循环 3.8 µs ± 46.6 ns(7 次运行的平均值 ± 标准偏差,每次 100000 次循环)
为什么急切的编译会导致执行速度变慢?
知道内存访问是连续的可以简化优化器的生命周期(这里是 对于 Cython,但类似的情况适用于 numba,即使 clang 通常比 gcc 更聪明)。
你的例子好像是这样的:
- 如果没有“急切的编译”,numba 将检测到数据是 C 连续的并利用它,例如用于矢量化。
- 使用 eager compilation,您不提供此信息,因此优化器必须考虑到内存访问可能是非连续的,并且会创建一个性能低于第一个版本的 jit 代码。
因此,您应该提供更准确的签名:
@jit(nb.float64[:, ::1](nb.int32[:, ::1]))
def go_fast3(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace
[:,::1]
tells numba,数据将是 C 连续的,一旦利用此信息:
x = np.arange(10000).astype(np.int32).reshape(100, 100)
%timeit go_fast(x) # 15.6 µs ± 241 ns per loop
%timeit go_fast2(x) # 8.2 µs ± 90.7 ns per loop
%timeit go_fast3(x) # 8.2 µs ± 49.6 ns per loop
急切编译版本没有区别。
附上一个最小的例子:
from numba import jit
import numba as nb
import numpy as np
@jit(nb.float64[:, :](nb.int32[:, :]))
def go_fast(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace
@jit
def go_fast2(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace
运行 在 Jupyter 中:
x = np.arange(10000).reshape(100, 100)
%timeit go_fast(x)
%timeit go_fast2(x)
导致
每个循环 5.65 µs ± 27.1 ns(7 次运行的平均值 ± 标准偏差,每次 100000 次循环)
每个循环 3.8 µs ± 46.6 ns(7 次运行的平均值 ± 标准偏差,每次 100000 次循环)
为什么急切的编译会导致执行速度变慢?
知道内存访问是连续的可以简化优化器的生命周期(这里是
你的例子好像是这样的:
- 如果没有“急切的编译”,numba 将检测到数据是 C 连续的并利用它,例如用于矢量化。
- 使用 eager compilation,您不提供此信息,因此优化器必须考虑到内存访问可能是非连续的,并且会创建一个性能低于第一个版本的 jit 代码。
因此,您应该提供更准确的签名:
@jit(nb.float64[:, ::1](nb.int32[:, ::1]))
def go_fast3(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace
[:,::1]
tells numba,数据将是 C 连续的,一旦利用此信息:
x = np.arange(10000).astype(np.int32).reshape(100, 100)
%timeit go_fast(x) # 15.6 µs ± 241 ns per loop
%timeit go_fast2(x) # 8.2 µs ± 90.7 ns per loop
%timeit go_fast3(x) # 8.2 µs ± 49.6 ns per loop
急切编译版本没有区别。