具有共享非收缩轴的 numpy einsum/tensordot

numpy einsum/tensordot with shared non-contracted axis

假设我有两个数组:

import numpy as np
a = np.random.randn(32, 6, 6, 20, 64, 3, 3)
b = np.random.randn(20, 128, 64, 3, 3)

并希望对最后 3 个轴求和,并保留共享轴。输出维度应该是(32,6,6,20,128)。请注意,此处带有 20 的轴在 ab 中共享。我们称此轴为“组”轴。

这个任务我有两种方法:
第一个只是简单的 einsum:

def method1(a, b):
    return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True)  # output shape:(32,6,6,20,128)

在第二种方法中,我遍历组维度并使用 einsum/tensordot 计算每个组维度的结果,然后堆叠结果:

def method2(a, b):
    result = []
    for g in range(b.shape[0]): # loop through each group dimension
        # result.append(np.tensordot(a[..., g, :, :, :], b[g, ...], axes=((-3,-2,-1),(-3,-2,-1))))
        result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True))  # output shape:(32,6,6,128)
    return np.stack(result, axis=-2)  # output shape:(32,6,6,20,128)

这是我的 jupyter notebook 中两种方法的时间安排:

我们可以看到带循环的第二种方法比第一种方法更快。

我的问题是:

  1. method1怎么这么慢?它不会计算更多的东西。
  2. 有没有不使用循环更高效的方法? (我有点不愿意使用循环,因为它们在 python 中很慢)

感谢您的帮助!

正如@Murali 在评论中指出的那样,method1 效率不高,因为它没有成功使用 BLAS 调用而不是 method2 确实如此。事实上,np.einsummethod1 中非常好,因为它按顺序计算结果,而 method2 并行 中主要是 运行s,这要归功于OpenBLAS(Numpy 在大多数机器上使用)。也就是说,method2 是 sub-optimal,因为它没有完全使用可用内核(部分计算是按顺序完成的)并且似乎没有使用缓存有效率的。在我的 6 核机器上,它几乎没有使用所有内核的 50%。


实施速度更快

加速此计算的一个解决方案是为此编写 highly-optimized Numba 并行代码。

首先,semi-naive 实现是使用许多 for 循环来计算爱因斯坦求和并重塑 input/output 数组,以便 Numba 可以更好地优化代码(例如展开,使用SIMD instructions)。这是结果:

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])')
def compute(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.empty((sN*sH*sW, sg, so), dtype=np.float64)

    for NHW in range(sN*sH*sW):
        for g in range(sg):
            for o in range(so):
                s = 0.0

                # Reduction
                for ihw in range(si*sh*sw):
                    s += ra[NHW, g, ihw] * rb[g, o, ihw]

                out[NHW, g, o] = s

    return out.reshape((sN, sH, sW, sg, so))

请注意,假定输入数组是连续的。如果不是这种情况,请考虑执行复制(与计算相比成本低)。

虽然上面的代码有效,但远非高效。以下是一些可以执行的改进:

  • 运行 最外层 NHW 循环 并行 ;
  • 使用 Numba 标志 fastmath=True。如果输入数据包含 NaN 或 +inf/-inf 等特殊值,则此标志 不安全 。然而,这个标志帮助编译器使用 SIMD 指令生成更快的代码(否则这是不可能的,因为 IEEE-754 floating-point 操作不是关联的);
  • 交换基于 NHW 的循环和基于 g 的循环会产生更好的性能,因为它改进了 cache-locality(rb 更有可能适合 last-level 主流 CPU 的缓存,否则它可能会从 RAM 中获取);
  • 利用register blocking让处理器更好的SIMD计算单元饱和,降低内存层级的压力;
  • 通过拆分基于 o 的循环来利用 tiling,因此 rb 几乎可以完全从 lower-level 缓存(例如 L1 或 L2)中读取。

除最后一项外,所有这些改进都在以下代码中实现:

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])', parallel=True, fastmath=True)
def method3(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)

    for g in range(sg):
        for k in nb.prange((sN*sH*sW)//2):
            NHW = k*2
            so_vect_max = (so // 4) * 4

            for o in range(0, so_vect_max, 4):
                s00 = s01 = s02 = s03 = s10 = s11 = s12 = s13 = 0.0

                # Useful since Numba does not optimize well the following loop otherwise
                ra_row0 = ra[NHW+0, g, :]
                ra_row1 = ra[NHW+1, g, :]
                rb_row0 = rb[g, o+0, :]
                rb_row1 = rb[g, o+1, :]
                rb_row2 = rb[g, o+2, :]
                rb_row3 = rb[g, o+3, :]

                # Highly-optimized reduction using register blocking
                for ihw in range(si*sh*sw):
                    ra_0 = ra_row0[ihw]
                    ra_1 = ra_row1[ihw]
                    rb_0 = rb_row0[ihw]
                    rb_1 = rb_row1[ihw]
                    rb_2 = rb_row2[ihw]
                    rb_3 = rb_row3[ihw]
                    s00 += ra_0 * rb_0; s01 += ra_0 * rb_1
                    s02 += ra_0 * rb_2; s03 += ra_0 * rb_3
                    s10 += ra_1 * rb_0; s11 += ra_1 * rb_1
                    s12 += ra_1 * rb_2; s13 += ra_1 * rb_3

                out[NHW+0, g, o+0] = s00; out[NHW+0, g, o+1] = s01
                out[NHW+0, g, o+2] = s02; out[NHW+0, g, o+3] = s03
                out[NHW+1, g, o+0] = s10; out[NHW+1, g, o+1] = s11
                out[NHW+1, g, o+2] = s12; out[NHW+1, g, o+3] = s13

            # Remaining part for `o`
            for o in range(so_vect_max, so):
                for ihw in range(si*sh*sw):
                    out[NHW, g, o] += ra[NHW, g, ihw] * rb[g, o, ihw]
                    out[NHW+1, g, o] += ra[NHW+1, g, ihw] * rb[g, o, ihw]

        # Remaining part for `k`
        if (sN*sH*sW) % 2 == 1:
            k = sN*sH*sW - 1
            for o in range(so):
                for ihw in range(si*sh*sw):
                    out[k, g, o] += ra[k, g, ihw] * rb[g, o, ihw]


    return out.reshape((sN, sH, sW, sg, so))

此代码更加复杂和丑陋,但也更加高效。我没有实施平铺优化,因为它会使代码的可读性更差。但是,它应该会在 many-core 处理器(尤其是具有较小 L2/L3 缓存的处理器)上产生明显更快的代码。


性能结果

以下是我的 i5-9600KF 6 核处理器的性能结果:

method1:              816 ms
method2:              104 ms
method3:               40 ms
Theoretical optimal:    9 ms   (optimistic lower bound)

该代码比 method2 快约 2.7。由于最佳时间比 method3.

快约 4 倍,因此还有改进的余地

Numba 无法生成快速代码的主要原因是底层 JIT 无法有效地向量化循环。实施平铺策略应该会略微提高执行时间,使其非常接近最佳时间。平铺策略对于更大的阵列至关重要。如果 so 大得多,则尤其如此。

如果您想要更快的实现,您当然需要直接使用 SIMD 指令(遗憾的是不可移植)或 SIMD 库(例如 XSIMD)编写 C/C++ 本机代码。

如果您想要更快的实现,那么您需要使用更快的硬件(具有更多内核)或更专用的硬件。 Server-based GPU(即不是个人电脑)不应该能够加速很多这样的计算,因为你的输入很小,显然 compute-bound并大量使用 FMA floating-point 操作。第一步是尝试 cupy.einsum.


引擎盖下:low-level 分析

为了理解为什么method1不快,我检查了执行的代码。这是主循环:

1a0:┌─→; Part of the reduction (see below)
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1000]
    │  
    │  ; Decrement the number of loop cycle
    │  sub        r9,0x8 
    │  
    │  ; Prefetch items so to reduce the impact 
    │  ; of the latency of reading from the RAM.
    │  prefetcht0 BYTE PTR [r8]
    │  prefetcht0 BYTE PTR [rdi]
    │  
    │  ; Part of the reduction (see below)
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1000]
    │  
    │  ; Increment iterator for the two arrays
    │  add        rdi,0x40 
    │  add        r8,0x40 
    │  
    │  ; Main computational part: 
    │  ; reduction using add+mul SSE2 instructions
    │  addpd      xmm1,xmm0                     <--- Slow
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1030]
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1030]
    │  addpd      xmm1,xmm0                     <--- Slow
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1020]
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1020]
    │  addpd      xmm0,xmm1                     <--- Slow
    │  movapd     xmm1,XMMWORD PTR [rdi-0x1010]
    │  mulpd      xmm1,XMMWORD PTR [r8-0x1010]
    │  addpd      xmm1,xmm0                     <--- Slow
    │  
    │  ; Is the loop over? 
    │  ; If not, jump to the beginning of the loop.
    ├──cmp        r9,0x7 
    └──jg         1a0

事实证明,Numpy 使用 SSE2 指令集(在所有 x86-64 处理器上都可用)。然而,我的机器,就像几乎所有相对较新的处理器一样,都支持 AVX 指令集,它可以每条指令一次计算两倍以上的项目。我的机器还支持 fuse-multiply 添加指令 (FMA),在这种情况下速度是原来的两倍。此外,循环显然受 addpd 的限制,它在大部分相同的寄存器中累积结果。处理器无法有效地执行它们,因为 addpd 需要很少的延迟周期,并且在现代 x86-64 处理器上最多可以同时执行两个(这在这里是不可能的,因为只有 1 个指令可以在xmm1 一次)。

这里是method2主要计算部分的执行代码(dgemm调用OpenBLAS):

6a40:┌─→vbroadcastsd ymm0,QWORD PTR [rsi-0x60]
     │  vbroadcastsd ymm1,QWORD PTR [rsi-0x58]
     │  vbroadcastsd ymm2,QWORD PTR [rsi-0x50]
     │  vbroadcastsd ymm3,QWORD PTR [rsi-0x48]
     │  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi-0x80]
     │  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi-0x60]
     │  vbroadcastsd ymm0,QWORD PTR [rsi-0x40]
     │  vbroadcastsd ymm1,QWORD PTR [rsi-0x38]
     │  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi-0x40]
     │  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi-0x20]
     │  vbroadcastsd ymm2,QWORD PTR [rsi-0x30]
     │  vbroadcastsd ymm3,QWORD PTR [rsi-0x28]
     │  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi]
     │  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi+0x20]
     │  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi+0x40]
     │  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi+0x60]
     │  add          rsi,0x40
     │  add          rdi,0x100
     ├──dec          rax
     └──jne          6a40

这个循环要优化得多:它利用了 AVX 指令集和 FMA 指令集(即。vfmadd231pd 说明)。此外,循环展开得更好,并且没有像 Numpy 代码中那样的 latency/dependency 问题。然而,虽然这个循环是 highly-efficient,但由于在 Numpy 中完成了一些顺序检查并且在 OpenBLAS 中执行了顺序复制,因此核心没有得到有效使用。此外,我不确定在这种情况下循环是否有效地使用了缓存,因为很多 read/writes 是在我机器的 RAM 中执行的。实际上,由于许多缓存未命中,RAM 吞吐量约为 15 GiB/s(超过 35~40 GiB/s),而 method3 的吞吐量为 6 GiB/s(因此完成了更多工作在缓存中)执行速度明显加快。

下面是method3主要计算部分的执行代码:

.LBB0_5:
    vorpd   2880(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm2
    vmovupd %ymm2, 3040(%rsp)
    vorpd   2848(%rsp), %ymm8, %ymm1
    vpcmpeqd    %ymm2, %ymm2, %ymm2
    vgatherqpd  %ymm2, (%rsi,%ymm1,8), %ymm3
    vmovupd %ymm3, 3104(%rsp)
    vorpd   2912(%rsp), %ymm8, %ymm2
    vpcmpeqd    %ymm3, %ymm3, %ymm3
    vgatherqpd  %ymm3, (%rsi,%ymm2,8), %ymm4
    vmovupd %ymm4, 3136(%rsp)
    vorpd   2816(%rsp), %ymm8, %ymm3
    vpcmpeqd    %ymm4, %ymm4, %ymm4
    vgatherqpd  %ymm4, (%rsi,%ymm3,8), %ymm5
    vmovupd %ymm5, 3808(%rsp)
    vorpd   2784(%rsp), %ymm8, %ymm9
    vpcmpeqd    %ymm4, %ymm4, %ymm4
    vgatherqpd  %ymm4, (%rsi,%ymm9,8), %ymm5
    vmovupd %ymm5, 3840(%rsp)
    vorpd   2752(%rsp), %ymm8, %ymm10
    vpcmpeqd    %ymm4, %ymm4, %ymm4
    vgatherqpd  %ymm4, (%rsi,%ymm10,8), %ymm5
    vmovupd %ymm5, 3872(%rsp)
    vpaddq  2944(%rsp), %ymm8, %ymm4
    vorpd   2720(%rsp), %ymm8, %ymm11
    vpcmpeqd    %ymm13, %ymm13, %ymm13
    vgatherqpd  %ymm13, (%rsi,%ymm11,8), %ymm5
    vmovupd %ymm5, 3904(%rsp)
    vpcmpeqd    %ymm13, %ymm13, %ymm13
    vgatherqpd  %ymm13, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 3552(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm1,8), %ymm5
    vmovupd %ymm5, 3616(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm2,8), %ymm1
    vmovupd %ymm1, 3648(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm3,8), %ymm1
    vmovupd %ymm1, 3680(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm9,8), %ymm1
    vmovupd %ymm1, 3712(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm10,8), %ymm1
    vmovupd %ymm1, 3744(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm11,8), %ymm1
    vmovupd %ymm1, 3776(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rsi,%ymm4,8), %ymm6
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm4,8), %ymm3
    vpaddq  2688(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm7
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3360(%rsp)
    vpaddq  2656(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm13
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3392(%rsp)
    vpaddq  2624(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm15
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3424(%rsp)
    vpaddq  2592(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm9
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3456(%rsp)
    vpaddq  2560(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm14
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3488(%rsp)
    vpaddq  2528(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm11
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3520(%rsp)
    vpaddq  2496(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm10
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3584(%rsp)
    vpaddq  2464(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
    vpaddq  2432(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm12
    vpaddq  2400(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3168(%rsp)
    vpaddq  2368(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3200(%rsp)
    vpaddq  2336(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3232(%rsp)
    vpaddq  2304(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3264(%rsp)
    vpaddq  2272(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3296(%rsp)
    vpaddq  2240(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3328(%rsp)
    vpaddq  2208(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vpaddq  2176(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 2976(%rsp)
    vpaddq  2144(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 3008(%rsp)
    vpaddq  2112(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 3072(%rsp)
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm8,8), %ymm0
    vpcmpeqd    %ymm5, %ymm5, %ymm5
    vgatherqpd  %ymm5, (%rdx,%ymm8,8), %ymm1
    vmovupd 768(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm1, %ymm5
    vmovupd %ymm5, 768(%rsp)
    vmovupd 32(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm3, %ymm5
    vmovupd %ymm5, 32(%rsp)
    vmovupd 1024(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm2, %ymm5
    vmovupd %ymm5, 1024(%rsp)
    vmovupd 1280(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm4, %ymm5
    vmovupd %ymm5, 1280(%rsp)
    vmovupd 1344(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 1344(%rsp)
    vmovupd 480(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm6, %ymm0
    vmovupd %ymm0, 480(%rsp)
    vmovupd 1600(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm6, %ymm0
    vmovupd %ymm0, 1600(%rsp)
    vmovupd 1856(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm6, %ymm0
    vmovupd %ymm0, 1856(%rsp)
    vpaddq  2080(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
    vpaddq  2048(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd 800(%rsp), %ymm0
    vmovupd 3552(%rsp), %ymm1
    vmovupd 3040(%rsp), %ymm3
    vfmadd231pd %ymm3, %ymm1, %ymm0
    vmovupd %ymm0, 800(%rsp)
    vmovupd 64(%rsp), %ymm0
    vmovupd 3360(%rsp), %ymm5
    vfmadd231pd %ymm3, %ymm5, %ymm0
    vmovupd %ymm0, 64(%rsp)
    vmovupd 1056(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm12, %ymm0
    vmovupd %ymm0, 1056(%rsp)
    vmovupd 288(%rsp), %ymm0
    vmovupd 2976(%rsp), %ymm6
    vfmadd231pd %ymm3, %ymm6, %ymm0
    vmovupd %ymm0, 288(%rsp)
    vmovupd 1376(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm7, %ymm0
    vmovupd %ymm0, 1376(%rsp)
    vmovupd 512(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm7, %ymm0
    vmovupd %ymm0, 512(%rsp)
    vmovupd 1632(%rsp), %ymm0
    vfmadd231pd %ymm12, %ymm7, %ymm0
    vmovupd %ymm0, 1632(%rsp)
    vmovupd 1888(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm7, %ymm0
    vmovupd %ymm0, 1888(%rsp)
    vmovupd 832(%rsp), %ymm0
    vmovupd 3616(%rsp), %ymm1
    vmovupd 3104(%rsp), %ymm6
    vfmadd231pd %ymm6, %ymm1, %ymm0
    vmovupd %ymm0, 832(%rsp)
    vmovupd 96(%rsp), %ymm0
    vmovupd 3392(%rsp), %ymm3
    vfmadd231pd %ymm6, %ymm3, %ymm0
    vmovupd %ymm0, 96(%rsp)
    vmovupd 1088(%rsp), %ymm0
    vmovupd 3168(%rsp), %ymm5
    vfmadd231pd %ymm6, %ymm5, %ymm0
    vmovupd %ymm0, 1088(%rsp)
    vmovupd 320(%rsp), %ymm0
    vmovupd 3008(%rsp), %ymm7
    vfmadd231pd %ymm6, %ymm7, %ymm0
    vmovupd %ymm0, 320(%rsp)
    vmovupd 1408(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm13, %ymm0
    vmovupd %ymm0, 1408(%rsp)
    vmovupd 544(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm13, %ymm0
    vmovupd %ymm0, 544(%rsp)
    vmovupd 1664(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm13, %ymm0
    vmovupd %ymm0, 1664(%rsp)
    vmovupd 1920(%rsp), %ymm0
    vfmadd231pd %ymm7, %ymm13, %ymm0
    vmovupd %ymm0, 1920(%rsp)
    vpaddq  2016(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm3
    vmovupd 864(%rsp), %ymm0
    vmovupd 3648(%rsp), %ymm1
    vmovupd 3136(%rsp), %ymm6
    vfmadd231pd %ymm6, %ymm1, %ymm0
    vmovupd %ymm0, 864(%rsp)
    vmovupd 128(%rsp), %ymm0
    vmovupd 3424(%rsp), %ymm5
    vfmadd231pd %ymm6, %ymm5, %ymm0
    vmovupd %ymm0, 128(%rsp)
    vmovupd 1120(%rsp), %ymm0
    vmovupd 3200(%rsp), %ymm7
    vfmadd231pd %ymm6, %ymm7, %ymm0
    vmovupd %ymm0, 1120(%rsp)
    vmovupd 352(%rsp), %ymm0
    vmovupd 3072(%rsp), %ymm12
    vfmadd231pd %ymm6, %ymm12, %ymm0
    vmovupd %ymm0, 352(%rsp)
    vmovupd 1440(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm15, %ymm0
    vmovupd %ymm0, 1440(%rsp)
    vmovupd 576(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm15, %ymm0
    vmovupd %ymm0, 576(%rsp)
    vmovupd 1696(%rsp), %ymm0
    vfmadd231pd %ymm7, %ymm15, %ymm0
    vmovupd %ymm0, 1696(%rsp)
    vmovupd 736(%rsp), %ymm0
    vfmadd231pd %ymm12, %ymm15, %ymm0
    vmovupd %ymm0, 736(%rsp)
    vmovupd 896(%rsp), %ymm0
    vmovupd 3808(%rsp), %ymm1
    vmovupd 3680(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 896(%rsp)
    vmovupd 160(%rsp), %ymm0
    vmovupd 3456(%rsp), %ymm6
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 160(%rsp)
    vmovupd 1152(%rsp), %ymm0
    vmovupd 3232(%rsp), %ymm7
    vfmadd231pd %ymm1, %ymm7, %ymm0
    vmovupd %ymm0, 1152(%rsp)
    vmovupd 384(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm2, %ymm0
    vmovupd %ymm0, 384(%rsp)
    vmovupd 1472(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm9, %ymm0
    vmovupd %ymm0, 1472(%rsp)
    vmovupd 608(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm9, %ymm0
    vmovupd %ymm0, 608(%rsp)
    vmovupd 1728(%rsp), %ymm0
    vfmadd231pd %ymm7, %ymm9, %ymm0
    vmovupd %ymm0, 1728(%rsp)
    vmovupd -128(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm9, %ymm0
    vmovupd %ymm0, -128(%rsp)
    vmovupd 928(%rsp), %ymm0
    vmovupd 3840(%rsp), %ymm1
    vmovupd 3712(%rsp), %ymm2
    vfmadd231pd %ymm1, %ymm2, %ymm0
    vmovupd %ymm0, 928(%rsp)
    vmovupd 192(%rsp), %ymm0
    vmovupd 3488(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 192(%rsp)
    vmovupd 1184(%rsp), %ymm0
    vmovupd 3264(%rsp), %ymm6
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 1184(%rsp)
    vmovupd 416(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm4, %ymm0
    vmovupd %ymm0, 416(%rsp)
    vmovupd 1504(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm14, %ymm0
    vmovupd %ymm0, 1504(%rsp)
    vmovupd 640(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm14, %ymm0
    vmovupd %ymm0, 640(%rsp)
    vmovupd 1760(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm14, %ymm0
    vmovupd %ymm0, 1760(%rsp)
    vmovupd -96(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm14, %ymm0
    vmovupd %ymm0, -96(%rsp)
    vpaddq  1984(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
    vmovupd 960(%rsp), %ymm0
    vmovupd 3872(%rsp), %ymm1
    vmovupd 3744(%rsp), %ymm4
    vfmadd231pd %ymm1, %ymm4, %ymm0
    vmovupd %ymm0, 960(%rsp)
    vmovupd 224(%rsp), %ymm0
    vmovupd 3520(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 224(%rsp)
    vmovupd 1216(%rsp), %ymm0
    vmovupd 3296(%rsp), %ymm6
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 1216(%rsp)
    vmovupd 448(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm3, %ymm0
    vmovupd %ymm0, 448(%rsp)
    vmovupd 1536(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm11, %ymm0
    vmovupd %ymm0, 1536(%rsp)
    vmovupd 672(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm11, %ymm0
    vmovupd %ymm0, 672(%rsp)
    vmovupd 1792(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm11, %ymm0
    vmovupd %ymm0, 1792(%rsp)
    vmovupd -64(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm11, %ymm0
    vmovupd %ymm0, -64(%rsp)
    vmovupd 992(%rsp), %ymm0
    vmovupd 3904(%rsp), %ymm1
    vmovupd 3776(%rsp), %ymm3
    vfmadd231pd %ymm1, %ymm3, %ymm0
    vmovupd %ymm0, 992(%rsp)
    vmovupd 256(%rsp), %ymm0
    vmovupd 3584(%rsp), %ymm4
    vfmadd231pd %ymm1, %ymm4, %ymm0
    vmovupd %ymm0, 256(%rsp)
    vmovupd 1248(%rsp), %ymm0
    vmovupd 3328(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 1248(%rsp)
    vmovupd 1312(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm2, %ymm0
    vmovupd %ymm0, 1312(%rsp)
    vmovupd 1568(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm10, %ymm0
    vmovupd %ymm0, 1568(%rsp)
    vmovupd 704(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm10, %ymm0
    vmovupd %ymm0, 704(%rsp)
    vmovupd 1824(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm10, %ymm0
    vmovupd %ymm0, 1824(%rsp)
    vmovupd -32(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm10, %ymm0
    vmovupd %ymm0, -32(%rsp)
    vpaddq  1952(%rsp), %ymm8, %ymm8
    addq    $-4, %rcx
    jne .LBB0_5

循环很大,显然没有正确矢量化:有很多完全无用的指令,内存加载似乎不连续(参见 vgatherqpd)。 Numba 不会生成好的代码,因为底层 JIT (LLVM-Lite) 无法有效地矢量化代码。事实上,我发现类似的 C++ 代码在 simplified example 上被 Clang 13.0 严重矢量化(GCC 和 ICC 在更复杂的代码上也失败),而 hand-written SIMD 实现工作得更好。它看起来像是优化器的错误,或者至少是错过了优化。这就是 Numba 代码比最优代码慢得多的原因。话虽这么说,这个实现非常有效地使用了缓存,并且是正确的多线程。

我还发现 BLAS 代码在 Linux 上比在我的机器上 Windows 更快(默认包来自 PIP 和相同的 Numpy 版本 1.20.3)。因此,method2method3 之间的差距更小,但后者仍然快得多。