Numba np.convolve 真的很慢

Numba np.convolve really slow

我正在尝试加速一段代码,该代码将一维数组(过滤器)与二维数组的每一列进行卷积。不知何故,当我 运行 使用 numba 的 njit 时,速度降低了 7 倍。我的想法:

(在 Windows 10、python 3.9.4 来自 conda、numpy 1.12.2、numba 0.53.1 上测试)

谁能告诉我为什么这段代码很慢?

import numpy as np
from numba import njit

def f1(a1, filt):
    l2 = filt.size // 2
    res = np.empty(a1.shape)
    for i in range(a1.shape[1]):
        res[:, i] = np.convolve(a1[:, i], filt)[l2:-l2]
    return res

@njit
def f1_jit(a1, filt):
    l2 = filt.size // 2
    res = np.empty(a1.shape)
    for i in range(a1.shape[1]):
        res[:, i] = np.convolve(a1[:, i], filt)[l2:-l2]
    return res

a1 = np.random.random((6400, 1000))
filt = np.random.random((65))
f1(a1, filt)
f1_jit(a1, filt)

%timeit f1(a1, filt)     # 404 ms ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f1_jit(a1, filt) # 2.8 s ± 66.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

问题来自 np.convolve 的 Numba 实现。这是一个known issue。事实证明,当前的 Numba 实现比 Numpy(在 Windows 上测试的版本 <=0.54.1)慢得多。


引擎盖下

一方面,应该使用np.dotNumpy implementation call correlate which itself performs a dot product that should be implemented by the fast BLAS library available on your system. On the other hand, the Numba implementation calls _get_inner_prod也使用相同的BLAS库(假设检测到一个BLAS应该是案例)...

也就是说,有多个 与点积相关的问题

首先,如果 numba/np/arraymath.py 的内部变量 _HAVE_BLAS 被手动禁用,Numba 使用点积的 回退实现 明显慢。然而,事实证明,使用 np.convolve 使用的后备点积实现比我机器上的 BLAS 包装器执行速度快 5 倍!在 njit Numba 装饰器中额外使用参数 fastmath=True 可使整体执行速度提高 8.7 倍!这是测试代码:

import numpy as np
import numba as nb

def npConvolve(a, b):
    return np.convolve(a, b)

@nb.njit('float64[:](float64[:], float64[:])')
def nbConvolveUncont(a, b):
    return np.convolve(a, b)

@nb.njit('float64[::1](float64[::1], float64[::1])')
def nbConvolveCont(a, b):
    return np.convolve(a, b)

a = np.random.random(6400)
b = np.random.random(65)
%timeit -n 100 npConvolve(a, b)
%timeit -n 100 nbConvolveUncont(a, b)
%timeit -n 100 nbConvolveCont(a, b)

以下是有趣的原始结果:

With _HAVE_BLAS=True (default):
126 µs ± 292 ns per loop
1.6 ms ± 21.3 µs per loop
1.6 ms ± 18.5 µs per loop

With _HAVE_BLAS=False:
125 µs ± 359 ns per loop
311 µs ± 1.18 µs per loop
268 µs ± 4.26 µs per loop

With _HAVE_BLAS=False and fastmath=True:
125 µs ± 757 ns per loop
327 µs ± 3.69 µs per loop
183 µs ± 654 ns per loop

此外,Numba 的 np_convolve 在内部 翻转一些数组参数 然后使用具有非平凡步幅(即不是 1)的翻转数组执行点积).这样的 不平凡的步幅可能会对点积性能产生影响 。更一般地说,任何阻止编译器知道数组是连续的转换肯定会严重影响性能。实际上,以下测试显示了使用 Numba 的点积实现处理连续数组的影响:

import numpy as np
import numba as nb

def np_dot(a, b):
    return np.dot(a, b)

@nb.njit('float64(float64[::1], float64[::1])')
def nb_dot_cont(a, b):
    return np.dot(a, b)

@nb.njit('float64(float64[::1], float64[:])')
def nb_dot_stride(a, b):
    return np.dot(a, b)

v = np.random.random(128*1024)
%timeit -n 200 np_dot(v, v)         #  36.5 µs ±  4.9 µs per loop
%timeit -n 200 nb_dot_stride(v, v)  # 361.0 µs ± 17.1 µs per loop  (x10 !!!)
%timeit -n 200 nb_dot_cont(v, v)    #  34.1 µs ±  2.9 µs per loop

关于 Numpy 和 Numba 的一些一般说明

请注意,Numba 在处理相当大的数组时几乎无法加速 Numpy 调用,因为 Numba 主要在 Python 中重新实现 Numpy 函数 并使用 JIT 编译器 (LLVM-Lite) 来加速它们,而 Numpy 主要是用纯 C 语言实现的(使用相当慢的 Python 包装代码)。 Numpy 代码使用 SIMD 指令 等低级处理器功能来加快许多功能的执行速度。两者似乎都使用已知高度优化的 BLAS 库。 Numpy 往往更优化,因为 Numpy 目前比 Numba 更成熟:Numpy 有更多的贡献者工作了很长时间。