具有共享非收缩轴的 numpy einsum/tensordot
numpy einsum/tensordot with shared non-contracted axis
假设我有两个数组:
import numpy as np
a = np.random.randn(32, 6, 6, 20, 64, 3, 3)
b = np.random.randn(20, 128, 64, 3, 3)
并希望对最后 3 个轴求和,并保留共享轴。输出维度应该是(32,6,6,20,128)
。请注意,此处带有 20 的轴在 a
和 b
中共享。我们称此轴为“组”轴。
这个任务我有两种方法:
第一个只是简单的 einsum
:
def method1(a, b):
return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True) # output shape:(32,6,6,20,128)
在第二种方法中,我遍历组维度并使用 einsum
/tensordot
计算每个组维度的结果,然后堆叠结果:
def method2(a, b):
result = []
for g in range(b.shape[0]): # loop through each group dimension
# result.append(np.tensordot(a[..., g, :, :, :], b[g, ...], axes=((-3,-2,-1),(-3,-2,-1))))
result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True)) # output shape:(32,6,6,128)
return np.stack(result, axis=-2) # output shape:(32,6,6,20,128)
这是我的 jupyter notebook 中两种方法的时间安排:
我们可以看到带循环的第二种方法比第一种方法更快。
我的问题是:
- method1怎么这么慢?它不会计算更多的东西。
- 有没有不使用循环更高效的方法? (我有点不愿意使用循环,因为它们在 python 中很慢)
感谢您的帮助!
正如@Murali 在评论中指出的那样,method1
效率不高,因为它没有成功使用 BLAS 调用而不是 method2
确实如此。事实上,np.einsum
在 method1
中非常好,因为它按顺序计算结果,而 method2
在 并行 中主要是 运行s,这要归功于OpenBLAS(Numpy 在大多数机器上使用)。也就是说,method2
是 sub-optimal,因为它没有完全使用可用内核(部分计算是按顺序完成的)并且似乎没有使用缓存有效率的。在我的 6 核机器上,它几乎没有使用所有内核的 50%。
实施速度更快
加速此计算的一个解决方案是为此编写 highly-optimized Numba 并行代码。
首先,semi-naive 实现是使用许多 for 循环来计算爱因斯坦求和并重塑 input/output 数组,以便 Numba 可以更好地优化代码(例如展开,使用SIMD instructions)。这是结果:
@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])')
def compute(a, b):
sN, sH, sW, sg, si, sh, sw = a.shape
so = b.shape[1]
assert b.shape == (sg, so, si, sh, sw)
ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
rb = b.reshape(sg, so, si*sh*sw)
out = np.empty((sN*sH*sW, sg, so), dtype=np.float64)
for NHW in range(sN*sH*sW):
for g in range(sg):
for o in range(so):
s = 0.0
# Reduction
for ihw in range(si*sh*sw):
s += ra[NHW, g, ihw] * rb[g, o, ihw]
out[NHW, g, o] = s
return out.reshape((sN, sH, sW, sg, so))
请注意,假定输入数组是连续的。如果不是这种情况,请考虑执行复制(与计算相比成本低)。
虽然上面的代码有效,但远非高效。以下是一些可以执行的改进:
- 运行 最外层
NHW
循环 并行 ;
- 使用 Numba 标志
fastmath=True
。如果输入数据包含 NaN 或 +inf/-inf 等特殊值,则此标志 不安全 。然而,这个标志帮助编译器使用 SIMD 指令生成更快的代码(否则这是不可能的,因为 IEEE-754 floating-point 操作不是关联的);
- 交换基于
NHW
的循环和基于 g
的循环会产生更好的性能,因为它改进了 cache-locality(rb
更有可能适合 last-level 主流 CPU 的缓存,否则它可能会从 RAM 中获取);
- 利用register blocking让处理器更好的SIMD计算单元饱和,降低内存层级的压力;
- 通过拆分基于
o
的循环来利用 tiling,因此 rb
几乎可以完全从 lower-level 缓存(例如 L1 或 L2)中读取。
除最后一项外,所有这些改进都在以下代码中实现:
@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])', parallel=True, fastmath=True)
def method3(a, b):
sN, sH, sW, sg, si, sh, sw = a.shape
so = b.shape[1]
assert b.shape == (sg, so, si, sh, sw)
ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
rb = b.reshape(sg, so, si*sh*sw)
out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)
for g in range(sg):
for k in nb.prange((sN*sH*sW)//2):
NHW = k*2
so_vect_max = (so // 4) * 4
for o in range(0, so_vect_max, 4):
s00 = s01 = s02 = s03 = s10 = s11 = s12 = s13 = 0.0
# Useful since Numba does not optimize well the following loop otherwise
ra_row0 = ra[NHW+0, g, :]
ra_row1 = ra[NHW+1, g, :]
rb_row0 = rb[g, o+0, :]
rb_row1 = rb[g, o+1, :]
rb_row2 = rb[g, o+2, :]
rb_row3 = rb[g, o+3, :]
# Highly-optimized reduction using register blocking
for ihw in range(si*sh*sw):
ra_0 = ra_row0[ihw]
ra_1 = ra_row1[ihw]
rb_0 = rb_row0[ihw]
rb_1 = rb_row1[ihw]
rb_2 = rb_row2[ihw]
rb_3 = rb_row3[ihw]
s00 += ra_0 * rb_0; s01 += ra_0 * rb_1
s02 += ra_0 * rb_2; s03 += ra_0 * rb_3
s10 += ra_1 * rb_0; s11 += ra_1 * rb_1
s12 += ra_1 * rb_2; s13 += ra_1 * rb_3
out[NHW+0, g, o+0] = s00; out[NHW+0, g, o+1] = s01
out[NHW+0, g, o+2] = s02; out[NHW+0, g, o+3] = s03
out[NHW+1, g, o+0] = s10; out[NHW+1, g, o+1] = s11
out[NHW+1, g, o+2] = s12; out[NHW+1, g, o+3] = s13
# Remaining part for `o`
for o in range(so_vect_max, so):
for ihw in range(si*sh*sw):
out[NHW, g, o] += ra[NHW, g, ihw] * rb[g, o, ihw]
out[NHW+1, g, o] += ra[NHW+1, g, ihw] * rb[g, o, ihw]
# Remaining part for `k`
if (sN*sH*sW) % 2 == 1:
k = sN*sH*sW - 1
for o in range(so):
for ihw in range(si*sh*sw):
out[k, g, o] += ra[k, g, ihw] * rb[g, o, ihw]
return out.reshape((sN, sH, sW, sg, so))
此代码更加复杂和丑陋,但也更加高效。我没有实施平铺优化,因为它会使代码的可读性更差。但是,它应该会在 many-core 处理器(尤其是具有较小 L2/L3 缓存的处理器)上产生明显更快的代码。
性能结果
以下是我的 i5-9600KF 6 核处理器的性能结果:
method1: 816 ms
method2: 104 ms
method3: 40 ms
Theoretical optimal: 9 ms (optimistic lower bound)
该代码比 method2
快约 2.7。由于最佳时间比 method3
.
快约 4 倍,因此还有改进的余地
Numba 无法生成快速代码的主要原因是底层 JIT 无法有效地向量化循环。实施平铺策略应该会略微提高执行时间,使其非常接近最佳时间。平铺策略对于更大的阵列至关重要。如果 so
大得多,则尤其如此。
如果您想要更快的实现,您当然需要直接使用 SIMD 指令(遗憾的是不可移植)或 SIMD 库(例如 XSIMD)编写 C/C++ 本机代码。
如果您想要更快的实现,那么您需要使用更快的硬件(具有更多内核)或更专用的硬件。 Server-based GPU(即不是个人电脑)不应该能够加速很多这样的计算,因为你的输入很小,显然 compute-bound并大量使用 FMA floating-point 操作。第一步是尝试 cupy.einsum
.
引擎盖下:low-level 分析
为了理解为什么method1
不快,我检查了执行的代码。这是主循环:
1a0:┌─→; Part of the reduction (see below)
│ movapd xmm0,XMMWORD PTR [rdi-0x1000]
│
│ ; Decrement the number of loop cycle
│ sub r9,0x8
│
│ ; Prefetch items so to reduce the impact
│ ; of the latency of reading from the RAM.
│ prefetcht0 BYTE PTR [r8]
│ prefetcht0 BYTE PTR [rdi]
│
│ ; Part of the reduction (see below)
│ mulpd xmm0,XMMWORD PTR [r8-0x1000]
│
│ ; Increment iterator for the two arrays
│ add rdi,0x40
│ add r8,0x40
│
│ ; Main computational part:
│ ; reduction using add+mul SSE2 instructions
│ addpd xmm1,xmm0 <--- Slow
│ movapd xmm0,XMMWORD PTR [rdi-0x1030]
│ mulpd xmm0,XMMWORD PTR [r8-0x1030]
│ addpd xmm1,xmm0 <--- Slow
│ movapd xmm0,XMMWORD PTR [rdi-0x1020]
│ mulpd xmm0,XMMWORD PTR [r8-0x1020]
│ addpd xmm0,xmm1 <--- Slow
│ movapd xmm1,XMMWORD PTR [rdi-0x1010]
│ mulpd xmm1,XMMWORD PTR [r8-0x1010]
│ addpd xmm1,xmm0 <--- Slow
│
│ ; Is the loop over?
│ ; If not, jump to the beginning of the loop.
├──cmp r9,0x7
└──jg 1a0
事实证明,Numpy 使用 SSE2 指令集(在所有 x86-64 处理器上都可用)。然而,我的机器,就像几乎所有相对较新的处理器一样,都支持 AVX 指令集,它可以每条指令一次计算两倍以上的项目。我的机器还支持 fuse-multiply 添加指令 (FMA),在这种情况下速度是原来的两倍。此外,循环显然受 addpd
的限制,它在大部分相同的寄存器中累积结果。处理器无法有效地执行它们,因为 addpd
需要很少的延迟周期,并且在现代 x86-64 处理器上最多可以同时执行两个(这在这里是不可能的,因为只有 1 个指令可以在xmm1
一次)。
这里是method2
主要计算部分的执行代码(dgemm
调用OpenBLAS):
6a40:┌─→vbroadcastsd ymm0,QWORD PTR [rsi-0x60]
│ vbroadcastsd ymm1,QWORD PTR [rsi-0x58]
│ vbroadcastsd ymm2,QWORD PTR [rsi-0x50]
│ vbroadcastsd ymm3,QWORD PTR [rsi-0x48]
│ vfmadd231pd ymm4,ymm0,YMMWORD PTR [rdi-0x80]
│ vfmadd231pd ymm5,ymm1,YMMWORD PTR [rdi-0x60]
│ vbroadcastsd ymm0,QWORD PTR [rsi-0x40]
│ vbroadcastsd ymm1,QWORD PTR [rsi-0x38]
│ vfmadd231pd ymm6,ymm2,YMMWORD PTR [rdi-0x40]
│ vfmadd231pd ymm7,ymm3,YMMWORD PTR [rdi-0x20]
│ vbroadcastsd ymm2,QWORD PTR [rsi-0x30]
│ vbroadcastsd ymm3,QWORD PTR [rsi-0x28]
│ vfmadd231pd ymm4,ymm0,YMMWORD PTR [rdi]
│ vfmadd231pd ymm5,ymm1,YMMWORD PTR [rdi+0x20]
│ vfmadd231pd ymm6,ymm2,YMMWORD PTR [rdi+0x40]
│ vfmadd231pd ymm7,ymm3,YMMWORD PTR [rdi+0x60]
│ add rsi,0x40
│ add rdi,0x100
├──dec rax
└──jne 6a40
这个循环要优化得多:它利用了 AVX 指令集和 FMA 指令集(即。vfmadd231pd
说明)。此外,循环展开得更好,并且没有像 Numpy 代码中那样的 latency/dependency 问题。然而,虽然这个循环是 highly-efficient,但由于在 Numpy 中完成了一些顺序检查并且在 OpenBLAS 中执行了顺序复制,因此核心没有得到有效使用。此外,我不确定在这种情况下循环是否有效地使用了缓存,因为很多 read/writes 是在我机器的 RAM 中执行的。实际上,由于许多缓存未命中,RAM 吞吐量约为 15 GiB/s(超过 35~40 GiB/s),而 method3
的吞吐量为 6 GiB/s(因此完成了更多工作在缓存中)执行速度明显加快。
下面是method3
主要计算部分的执行代码:
.LBB0_5:
vorpd 2880(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm2
vmovupd %ymm2, 3040(%rsp)
vorpd 2848(%rsp), %ymm8, %ymm1
vpcmpeqd %ymm2, %ymm2, %ymm2
vgatherqpd %ymm2, (%rsi,%ymm1,8), %ymm3
vmovupd %ymm3, 3104(%rsp)
vorpd 2912(%rsp), %ymm8, %ymm2
vpcmpeqd %ymm3, %ymm3, %ymm3
vgatherqpd %ymm3, (%rsi,%ymm2,8), %ymm4
vmovupd %ymm4, 3136(%rsp)
vorpd 2816(%rsp), %ymm8, %ymm3
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm3,8), %ymm5
vmovupd %ymm5, 3808(%rsp)
vorpd 2784(%rsp), %ymm8, %ymm9
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm9,8), %ymm5
vmovupd %ymm5, 3840(%rsp)
vorpd 2752(%rsp), %ymm8, %ymm10
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm10,8), %ymm5
vmovupd %ymm5, 3872(%rsp)
vpaddq 2944(%rsp), %ymm8, %ymm4
vorpd 2720(%rsp), %ymm8, %ymm11
vpcmpeqd %ymm13, %ymm13, %ymm13
vgatherqpd %ymm13, (%rsi,%ymm11,8), %ymm5
vmovupd %ymm5, 3904(%rsp)
vpcmpeqd %ymm13, %ymm13, %ymm13
vgatherqpd %ymm13, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3552(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm1,8), %ymm5
vmovupd %ymm5, 3616(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm2,8), %ymm1
vmovupd %ymm1, 3648(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm3,8), %ymm1
vmovupd %ymm1, 3680(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm9,8), %ymm1
vmovupd %ymm1, 3712(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm10,8), %ymm1
vmovupd %ymm1, 3744(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm11,8), %ymm1
vmovupd %ymm1, 3776(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rsi,%ymm4,8), %ymm6
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm4,8), %ymm3
vpaddq 2688(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm7
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3360(%rsp)
vpaddq 2656(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm13
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3392(%rsp)
vpaddq 2624(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm15
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3424(%rsp)
vpaddq 2592(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm9
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3456(%rsp)
vpaddq 2560(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm14
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3488(%rsp)
vpaddq 2528(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm11
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3520(%rsp)
vpaddq 2496(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm10
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3584(%rsp)
vpaddq 2464(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq 2432(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm12
vpaddq 2400(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3168(%rsp)
vpaddq 2368(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3200(%rsp)
vpaddq 2336(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3232(%rsp)
vpaddq 2304(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3264(%rsp)
vpaddq 2272(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3296(%rsp)
vpaddq 2240(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3328(%rsp)
vpaddq 2208(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vpaddq 2176(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 2976(%rsp)
vpaddq 2144(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3008(%rsp)
vpaddq 2112(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3072(%rsp)
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm8,8), %ymm0
vpcmpeqd %ymm5, %ymm5, %ymm5
vgatherqpd %ymm5, (%rdx,%ymm8,8), %ymm1
vmovupd 768(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm1, %ymm5
vmovupd %ymm5, 768(%rsp)
vmovupd 32(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm3, %ymm5
vmovupd %ymm5, 32(%rsp)
vmovupd 1024(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm2, %ymm5
vmovupd %ymm5, 1024(%rsp)
vmovupd 1280(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm4, %ymm5
vmovupd %ymm5, 1280(%rsp)
vmovupd 1344(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1344(%rsp)
vmovupd 480(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 480(%rsp)
vmovupd 1600(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm6, %ymm0
vmovupd %ymm0, 1600(%rsp)
vmovupd 1856(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm6, %ymm0
vmovupd %ymm0, 1856(%rsp)
vpaddq 2080(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq 2048(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd 800(%rsp), %ymm0
vmovupd 3552(%rsp), %ymm1
vmovupd 3040(%rsp), %ymm3
vfmadd231pd %ymm3, %ymm1, %ymm0
vmovupd %ymm0, 800(%rsp)
vmovupd 64(%rsp), %ymm0
vmovupd 3360(%rsp), %ymm5
vfmadd231pd %ymm3, %ymm5, %ymm0
vmovupd %ymm0, 64(%rsp)
vmovupd 1056(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm12, %ymm0
vmovupd %ymm0, 1056(%rsp)
vmovupd 288(%rsp), %ymm0
vmovupd 2976(%rsp), %ymm6
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 288(%rsp)
vmovupd 1376(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1376(%rsp)
vmovupd 512(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm7, %ymm0
vmovupd %ymm0, 512(%rsp)
vmovupd 1632(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm7, %ymm0
vmovupd %ymm0, 1632(%rsp)
vmovupd 1888(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1888(%rsp)
vmovupd 832(%rsp), %ymm0
vmovupd 3616(%rsp), %ymm1
vmovupd 3104(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 832(%rsp)
vmovupd 96(%rsp), %ymm0
vmovupd 3392(%rsp), %ymm3
vfmadd231pd %ymm6, %ymm3, %ymm0
vmovupd %ymm0, 96(%rsp)
vmovupd 1088(%rsp), %ymm0
vmovupd 3168(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 1088(%rsp)
vmovupd 320(%rsp), %ymm0
vmovupd 3008(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 320(%rsp)
vmovupd 1408(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm13, %ymm0
vmovupd %ymm0, 1408(%rsp)
vmovupd 544(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm13, %ymm0
vmovupd %ymm0, 544(%rsp)
vmovupd 1664(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm13, %ymm0
vmovupd %ymm0, 1664(%rsp)
vmovupd 1920(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm13, %ymm0
vmovupd %ymm0, 1920(%rsp)
vpaddq 2016(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm3
vmovupd 864(%rsp), %ymm0
vmovupd 3648(%rsp), %ymm1
vmovupd 3136(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 864(%rsp)
vmovupd 128(%rsp), %ymm0
vmovupd 3424(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 128(%rsp)
vmovupd 1120(%rsp), %ymm0
vmovupd 3200(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1120(%rsp)
vmovupd 352(%rsp), %ymm0
vmovupd 3072(%rsp), %ymm12
vfmadd231pd %ymm6, %ymm12, %ymm0
vmovupd %ymm0, 352(%rsp)
vmovupd 1440(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm15, %ymm0
vmovupd %ymm0, 1440(%rsp)
vmovupd 576(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm15, %ymm0
vmovupd %ymm0, 576(%rsp)
vmovupd 1696(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm15, %ymm0
vmovupd %ymm0, 1696(%rsp)
vmovupd 736(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm15, %ymm0
vmovupd %ymm0, 736(%rsp)
vmovupd 896(%rsp), %ymm0
vmovupd 3808(%rsp), %ymm1
vmovupd 3680(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 896(%rsp)
vmovupd 160(%rsp), %ymm0
vmovupd 3456(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 160(%rsp)
vmovupd 1152(%rsp), %ymm0
vmovupd 3232(%rsp), %ymm7
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1152(%rsp)
vmovupd 384(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 384(%rsp)
vmovupd 1472(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm9, %ymm0
vmovupd %ymm0, 1472(%rsp)
vmovupd 608(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm9, %ymm0
vmovupd %ymm0, 608(%rsp)
vmovupd 1728(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm9, %ymm0
vmovupd %ymm0, 1728(%rsp)
vmovupd -128(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm9, %ymm0
vmovupd %ymm0, -128(%rsp)
vmovupd 928(%rsp), %ymm0
vmovupd 3840(%rsp), %ymm1
vmovupd 3712(%rsp), %ymm2
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 928(%rsp)
vmovupd 192(%rsp), %ymm0
vmovupd 3488(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 192(%rsp)
vmovupd 1184(%rsp), %ymm0
vmovupd 3264(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1184(%rsp)
vmovupd 416(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 416(%rsp)
vmovupd 1504(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm14, %ymm0
vmovupd %ymm0, 1504(%rsp)
vmovupd 640(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm14, %ymm0
vmovupd %ymm0, 640(%rsp)
vmovupd 1760(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm14, %ymm0
vmovupd %ymm0, 1760(%rsp)
vmovupd -96(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm14, %ymm0
vmovupd %ymm0, -96(%rsp)
vpaddq 1984(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vmovupd 960(%rsp), %ymm0
vmovupd 3872(%rsp), %ymm1
vmovupd 3744(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 960(%rsp)
vmovupd 224(%rsp), %ymm0
vmovupd 3520(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 224(%rsp)
vmovupd 1216(%rsp), %ymm0
vmovupd 3296(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1216(%rsp)
vmovupd 448(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 448(%rsp)
vmovupd 1536(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm11, %ymm0
vmovupd %ymm0, 1536(%rsp)
vmovupd 672(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm11, %ymm0
vmovupd %ymm0, 672(%rsp)
vmovupd 1792(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm11, %ymm0
vmovupd %ymm0, 1792(%rsp)
vmovupd -64(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm11, %ymm0
vmovupd %ymm0, -64(%rsp)
vmovupd 992(%rsp), %ymm0
vmovupd 3904(%rsp), %ymm1
vmovupd 3776(%rsp), %ymm3
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 992(%rsp)
vmovupd 256(%rsp), %ymm0
vmovupd 3584(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 256(%rsp)
vmovupd 1248(%rsp), %ymm0
vmovupd 3328(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 1248(%rsp)
vmovupd 1312(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 1312(%rsp)
vmovupd 1568(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm10, %ymm0
vmovupd %ymm0, 1568(%rsp)
vmovupd 704(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm10, %ymm0
vmovupd %ymm0, 704(%rsp)
vmovupd 1824(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm10, %ymm0
vmovupd %ymm0, 1824(%rsp)
vmovupd -32(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm10, %ymm0
vmovupd %ymm0, -32(%rsp)
vpaddq 1952(%rsp), %ymm8, %ymm8
addq $-4, %rcx
jne .LBB0_5
循环很大,显然没有正确矢量化:有很多完全无用的指令,内存加载似乎不连续(参见 vgatherqpd
)。 Numba 不会生成好的代码,因为底层 JIT (LLVM-Lite) 无法有效地矢量化代码。事实上,我发现类似的 C++ 代码在 simplified example 上被 Clang 13.0 严重矢量化(GCC 和 ICC 在更复杂的代码上也失败),而 hand-written SIMD 实现工作得更好。它看起来像是优化器的错误,或者至少是错过了优化。这就是 Numba 代码比最优代码慢得多的原因。话虽这么说,这个实现非常有效地使用了缓存,并且是正确的多线程。
我还发现 BLAS 代码在 Linux 上比在我的机器上 Windows 更快(默认包来自 PIP 和相同的 Numpy 版本 1.20.3)。因此,method2
和 method3
之间的差距更小,但后者仍然快得多。
假设我有两个数组:
import numpy as np
a = np.random.randn(32, 6, 6, 20, 64, 3, 3)
b = np.random.randn(20, 128, 64, 3, 3)
并希望对最后 3 个轴求和,并保留共享轴。输出维度应该是(32,6,6,20,128)
。请注意,此处带有 20 的轴在 a
和 b
中共享。我们称此轴为“组”轴。
这个任务我有两种方法:
第一个只是简单的 einsum
:
def method1(a, b):
return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True) # output shape:(32,6,6,20,128)
在第二种方法中,我遍历组维度并使用 einsum
/tensordot
计算每个组维度的结果,然后堆叠结果:
def method2(a, b):
result = []
for g in range(b.shape[0]): # loop through each group dimension
# result.append(np.tensordot(a[..., g, :, :, :], b[g, ...], axes=((-3,-2,-1),(-3,-2,-1))))
result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True)) # output shape:(32,6,6,128)
return np.stack(result, axis=-2) # output shape:(32,6,6,20,128)
这是我的 jupyter notebook 中两种方法的时间安排:
我们可以看到带循环的第二种方法比第一种方法更快。
我的问题是:
- method1怎么这么慢?它不会计算更多的东西。
- 有没有不使用循环更高效的方法? (我有点不愿意使用循环,因为它们在 python 中很慢)
感谢您的帮助!
正如@Murali 在评论中指出的那样,method1
效率不高,因为它没有成功使用 BLAS 调用而不是 method2
确实如此。事实上,np.einsum
在 method1
中非常好,因为它按顺序计算结果,而 method2
在 并行 中主要是 运行s,这要归功于OpenBLAS(Numpy 在大多数机器上使用)。也就是说,method2
是 sub-optimal,因为它没有完全使用可用内核(部分计算是按顺序完成的)并且似乎没有使用缓存有效率的。在我的 6 核机器上,它几乎没有使用所有内核的 50%。
实施速度更快
加速此计算的一个解决方案是为此编写 highly-optimized Numba 并行代码。
首先,semi-naive 实现是使用许多 for 循环来计算爱因斯坦求和并重塑 input/output 数组,以便 Numba 可以更好地优化代码(例如展开,使用SIMD instructions)。这是结果:
@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])')
def compute(a, b):
sN, sH, sW, sg, si, sh, sw = a.shape
so = b.shape[1]
assert b.shape == (sg, so, si, sh, sw)
ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
rb = b.reshape(sg, so, si*sh*sw)
out = np.empty((sN*sH*sW, sg, so), dtype=np.float64)
for NHW in range(sN*sH*sW):
for g in range(sg):
for o in range(so):
s = 0.0
# Reduction
for ihw in range(si*sh*sw):
s += ra[NHW, g, ihw] * rb[g, o, ihw]
out[NHW, g, o] = s
return out.reshape((sN, sH, sW, sg, so))
请注意,假定输入数组是连续的。如果不是这种情况,请考虑执行复制(与计算相比成本低)。
虽然上面的代码有效,但远非高效。以下是一些可以执行的改进:
- 运行 最外层
NHW
循环 并行 ; - 使用 Numba 标志
fastmath=True
。如果输入数据包含 NaN 或 +inf/-inf 等特殊值,则此标志 不安全 。然而,这个标志帮助编译器使用 SIMD 指令生成更快的代码(否则这是不可能的,因为 IEEE-754 floating-point 操作不是关联的); - 交换基于
NHW
的循环和基于g
的循环会产生更好的性能,因为它改进了 cache-locality(rb
更有可能适合 last-level 主流 CPU 的缓存,否则它可能会从 RAM 中获取); - 利用register blocking让处理器更好的SIMD计算单元饱和,降低内存层级的压力;
- 通过拆分基于
o
的循环来利用 tiling,因此rb
几乎可以完全从 lower-level 缓存(例如 L1 或 L2)中读取。
除最后一项外,所有这些改进都在以下代码中实现:
@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])', parallel=True, fastmath=True)
def method3(a, b):
sN, sH, sW, sg, si, sh, sw = a.shape
so = b.shape[1]
assert b.shape == (sg, so, si, sh, sw)
ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
rb = b.reshape(sg, so, si*sh*sw)
out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)
for g in range(sg):
for k in nb.prange((sN*sH*sW)//2):
NHW = k*2
so_vect_max = (so // 4) * 4
for o in range(0, so_vect_max, 4):
s00 = s01 = s02 = s03 = s10 = s11 = s12 = s13 = 0.0
# Useful since Numba does not optimize well the following loop otherwise
ra_row0 = ra[NHW+0, g, :]
ra_row1 = ra[NHW+1, g, :]
rb_row0 = rb[g, o+0, :]
rb_row1 = rb[g, o+1, :]
rb_row2 = rb[g, o+2, :]
rb_row3 = rb[g, o+3, :]
# Highly-optimized reduction using register blocking
for ihw in range(si*sh*sw):
ra_0 = ra_row0[ihw]
ra_1 = ra_row1[ihw]
rb_0 = rb_row0[ihw]
rb_1 = rb_row1[ihw]
rb_2 = rb_row2[ihw]
rb_3 = rb_row3[ihw]
s00 += ra_0 * rb_0; s01 += ra_0 * rb_1
s02 += ra_0 * rb_2; s03 += ra_0 * rb_3
s10 += ra_1 * rb_0; s11 += ra_1 * rb_1
s12 += ra_1 * rb_2; s13 += ra_1 * rb_3
out[NHW+0, g, o+0] = s00; out[NHW+0, g, o+1] = s01
out[NHW+0, g, o+2] = s02; out[NHW+0, g, o+3] = s03
out[NHW+1, g, o+0] = s10; out[NHW+1, g, o+1] = s11
out[NHW+1, g, o+2] = s12; out[NHW+1, g, o+3] = s13
# Remaining part for `o`
for o in range(so_vect_max, so):
for ihw in range(si*sh*sw):
out[NHW, g, o] += ra[NHW, g, ihw] * rb[g, o, ihw]
out[NHW+1, g, o] += ra[NHW+1, g, ihw] * rb[g, o, ihw]
# Remaining part for `k`
if (sN*sH*sW) % 2 == 1:
k = sN*sH*sW - 1
for o in range(so):
for ihw in range(si*sh*sw):
out[k, g, o] += ra[k, g, ihw] * rb[g, o, ihw]
return out.reshape((sN, sH, sW, sg, so))
此代码更加复杂和丑陋,但也更加高效。我没有实施平铺优化,因为它会使代码的可读性更差。但是,它应该会在 many-core 处理器(尤其是具有较小 L2/L3 缓存的处理器)上产生明显更快的代码。
性能结果
以下是我的 i5-9600KF 6 核处理器的性能结果:
method1: 816 ms
method2: 104 ms
method3: 40 ms
Theoretical optimal: 9 ms (optimistic lower bound)
该代码比 method2
快约 2.7。由于最佳时间比 method3
.
Numba 无法生成快速代码的主要原因是底层 JIT 无法有效地向量化循环。实施平铺策略应该会略微提高执行时间,使其非常接近最佳时间。平铺策略对于更大的阵列至关重要。如果 so
大得多,则尤其如此。
如果您想要更快的实现,您当然需要直接使用 SIMD 指令(遗憾的是不可移植)或 SIMD 库(例如 XSIMD)编写 C/C++ 本机代码。
如果您想要更快的实现,那么您需要使用更快的硬件(具有更多内核)或更专用的硬件。 Server-based GPU(即不是个人电脑)不应该能够加速很多这样的计算,因为你的输入很小,显然 compute-bound并大量使用 FMA floating-point 操作。第一步是尝试 cupy.einsum
.
引擎盖下:low-level 分析
为了理解为什么method1
不快,我检查了执行的代码。这是主循环:
1a0:┌─→; Part of the reduction (see below)
│ movapd xmm0,XMMWORD PTR [rdi-0x1000]
│
│ ; Decrement the number of loop cycle
│ sub r9,0x8
│
│ ; Prefetch items so to reduce the impact
│ ; of the latency of reading from the RAM.
│ prefetcht0 BYTE PTR [r8]
│ prefetcht0 BYTE PTR [rdi]
│
│ ; Part of the reduction (see below)
│ mulpd xmm0,XMMWORD PTR [r8-0x1000]
│
│ ; Increment iterator for the two arrays
│ add rdi,0x40
│ add r8,0x40
│
│ ; Main computational part:
│ ; reduction using add+mul SSE2 instructions
│ addpd xmm1,xmm0 <--- Slow
│ movapd xmm0,XMMWORD PTR [rdi-0x1030]
│ mulpd xmm0,XMMWORD PTR [r8-0x1030]
│ addpd xmm1,xmm0 <--- Slow
│ movapd xmm0,XMMWORD PTR [rdi-0x1020]
│ mulpd xmm0,XMMWORD PTR [r8-0x1020]
│ addpd xmm0,xmm1 <--- Slow
│ movapd xmm1,XMMWORD PTR [rdi-0x1010]
│ mulpd xmm1,XMMWORD PTR [r8-0x1010]
│ addpd xmm1,xmm0 <--- Slow
│
│ ; Is the loop over?
│ ; If not, jump to the beginning of the loop.
├──cmp r9,0x7
└──jg 1a0
事实证明,Numpy 使用 SSE2 指令集(在所有 x86-64 处理器上都可用)。然而,我的机器,就像几乎所有相对较新的处理器一样,都支持 AVX 指令集,它可以每条指令一次计算两倍以上的项目。我的机器还支持 fuse-multiply 添加指令 (FMA),在这种情况下速度是原来的两倍。此外,循环显然受 addpd
的限制,它在大部分相同的寄存器中累积结果。处理器无法有效地执行它们,因为 addpd
需要很少的延迟周期,并且在现代 x86-64 处理器上最多可以同时执行两个(这在这里是不可能的,因为只有 1 个指令可以在xmm1
一次)。
这里是method2
主要计算部分的执行代码(dgemm
调用OpenBLAS):
6a40:┌─→vbroadcastsd ymm0,QWORD PTR [rsi-0x60]
│ vbroadcastsd ymm1,QWORD PTR [rsi-0x58]
│ vbroadcastsd ymm2,QWORD PTR [rsi-0x50]
│ vbroadcastsd ymm3,QWORD PTR [rsi-0x48]
│ vfmadd231pd ymm4,ymm0,YMMWORD PTR [rdi-0x80]
│ vfmadd231pd ymm5,ymm1,YMMWORD PTR [rdi-0x60]
│ vbroadcastsd ymm0,QWORD PTR [rsi-0x40]
│ vbroadcastsd ymm1,QWORD PTR [rsi-0x38]
│ vfmadd231pd ymm6,ymm2,YMMWORD PTR [rdi-0x40]
│ vfmadd231pd ymm7,ymm3,YMMWORD PTR [rdi-0x20]
│ vbroadcastsd ymm2,QWORD PTR [rsi-0x30]
│ vbroadcastsd ymm3,QWORD PTR [rsi-0x28]
│ vfmadd231pd ymm4,ymm0,YMMWORD PTR [rdi]
│ vfmadd231pd ymm5,ymm1,YMMWORD PTR [rdi+0x20]
│ vfmadd231pd ymm6,ymm2,YMMWORD PTR [rdi+0x40]
│ vfmadd231pd ymm7,ymm3,YMMWORD PTR [rdi+0x60]
│ add rsi,0x40
│ add rdi,0x100
├──dec rax
└──jne 6a40
这个循环要优化得多:它利用了 AVX 指令集和 FMA 指令集(即。vfmadd231pd
说明)。此外,循环展开得更好,并且没有像 Numpy 代码中那样的 latency/dependency 问题。然而,虽然这个循环是 highly-efficient,但由于在 Numpy 中完成了一些顺序检查并且在 OpenBLAS 中执行了顺序复制,因此核心没有得到有效使用。此外,我不确定在这种情况下循环是否有效地使用了缓存,因为很多 read/writes 是在我机器的 RAM 中执行的。实际上,由于许多缓存未命中,RAM 吞吐量约为 15 GiB/s(超过 35~40 GiB/s),而 method3
的吞吐量为 6 GiB/s(因此完成了更多工作在缓存中)执行速度明显加快。
下面是method3
主要计算部分的执行代码:
.LBB0_5:
vorpd 2880(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm2
vmovupd %ymm2, 3040(%rsp)
vorpd 2848(%rsp), %ymm8, %ymm1
vpcmpeqd %ymm2, %ymm2, %ymm2
vgatherqpd %ymm2, (%rsi,%ymm1,8), %ymm3
vmovupd %ymm3, 3104(%rsp)
vorpd 2912(%rsp), %ymm8, %ymm2
vpcmpeqd %ymm3, %ymm3, %ymm3
vgatherqpd %ymm3, (%rsi,%ymm2,8), %ymm4
vmovupd %ymm4, 3136(%rsp)
vorpd 2816(%rsp), %ymm8, %ymm3
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm3,8), %ymm5
vmovupd %ymm5, 3808(%rsp)
vorpd 2784(%rsp), %ymm8, %ymm9
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm9,8), %ymm5
vmovupd %ymm5, 3840(%rsp)
vorpd 2752(%rsp), %ymm8, %ymm10
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm10,8), %ymm5
vmovupd %ymm5, 3872(%rsp)
vpaddq 2944(%rsp), %ymm8, %ymm4
vorpd 2720(%rsp), %ymm8, %ymm11
vpcmpeqd %ymm13, %ymm13, %ymm13
vgatherqpd %ymm13, (%rsi,%ymm11,8), %ymm5
vmovupd %ymm5, 3904(%rsp)
vpcmpeqd %ymm13, %ymm13, %ymm13
vgatherqpd %ymm13, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3552(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm1,8), %ymm5
vmovupd %ymm5, 3616(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm2,8), %ymm1
vmovupd %ymm1, 3648(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm3,8), %ymm1
vmovupd %ymm1, 3680(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm9,8), %ymm1
vmovupd %ymm1, 3712(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm10,8), %ymm1
vmovupd %ymm1, 3744(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm11,8), %ymm1
vmovupd %ymm1, 3776(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rsi,%ymm4,8), %ymm6
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm4,8), %ymm3
vpaddq 2688(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm7
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3360(%rsp)
vpaddq 2656(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm13
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3392(%rsp)
vpaddq 2624(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm15
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3424(%rsp)
vpaddq 2592(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm9
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3456(%rsp)
vpaddq 2560(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm14
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3488(%rsp)
vpaddq 2528(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm11
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3520(%rsp)
vpaddq 2496(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm10
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3584(%rsp)
vpaddq 2464(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq 2432(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm12
vpaddq 2400(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3168(%rsp)
vpaddq 2368(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3200(%rsp)
vpaddq 2336(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3232(%rsp)
vpaddq 2304(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3264(%rsp)
vpaddq 2272(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3296(%rsp)
vpaddq 2240(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3328(%rsp)
vpaddq 2208(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vpaddq 2176(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 2976(%rsp)
vpaddq 2144(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3008(%rsp)
vpaddq 2112(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3072(%rsp)
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm8,8), %ymm0
vpcmpeqd %ymm5, %ymm5, %ymm5
vgatherqpd %ymm5, (%rdx,%ymm8,8), %ymm1
vmovupd 768(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm1, %ymm5
vmovupd %ymm5, 768(%rsp)
vmovupd 32(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm3, %ymm5
vmovupd %ymm5, 32(%rsp)
vmovupd 1024(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm2, %ymm5
vmovupd %ymm5, 1024(%rsp)
vmovupd 1280(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm4, %ymm5
vmovupd %ymm5, 1280(%rsp)
vmovupd 1344(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1344(%rsp)
vmovupd 480(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 480(%rsp)
vmovupd 1600(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm6, %ymm0
vmovupd %ymm0, 1600(%rsp)
vmovupd 1856(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm6, %ymm0
vmovupd %ymm0, 1856(%rsp)
vpaddq 2080(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq 2048(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd 800(%rsp), %ymm0
vmovupd 3552(%rsp), %ymm1
vmovupd 3040(%rsp), %ymm3
vfmadd231pd %ymm3, %ymm1, %ymm0
vmovupd %ymm0, 800(%rsp)
vmovupd 64(%rsp), %ymm0
vmovupd 3360(%rsp), %ymm5
vfmadd231pd %ymm3, %ymm5, %ymm0
vmovupd %ymm0, 64(%rsp)
vmovupd 1056(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm12, %ymm0
vmovupd %ymm0, 1056(%rsp)
vmovupd 288(%rsp), %ymm0
vmovupd 2976(%rsp), %ymm6
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 288(%rsp)
vmovupd 1376(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1376(%rsp)
vmovupd 512(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm7, %ymm0
vmovupd %ymm0, 512(%rsp)
vmovupd 1632(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm7, %ymm0
vmovupd %ymm0, 1632(%rsp)
vmovupd 1888(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1888(%rsp)
vmovupd 832(%rsp), %ymm0
vmovupd 3616(%rsp), %ymm1
vmovupd 3104(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 832(%rsp)
vmovupd 96(%rsp), %ymm0
vmovupd 3392(%rsp), %ymm3
vfmadd231pd %ymm6, %ymm3, %ymm0
vmovupd %ymm0, 96(%rsp)
vmovupd 1088(%rsp), %ymm0
vmovupd 3168(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 1088(%rsp)
vmovupd 320(%rsp), %ymm0
vmovupd 3008(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 320(%rsp)
vmovupd 1408(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm13, %ymm0
vmovupd %ymm0, 1408(%rsp)
vmovupd 544(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm13, %ymm0
vmovupd %ymm0, 544(%rsp)
vmovupd 1664(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm13, %ymm0
vmovupd %ymm0, 1664(%rsp)
vmovupd 1920(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm13, %ymm0
vmovupd %ymm0, 1920(%rsp)
vpaddq 2016(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm3
vmovupd 864(%rsp), %ymm0
vmovupd 3648(%rsp), %ymm1
vmovupd 3136(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 864(%rsp)
vmovupd 128(%rsp), %ymm0
vmovupd 3424(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 128(%rsp)
vmovupd 1120(%rsp), %ymm0
vmovupd 3200(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1120(%rsp)
vmovupd 352(%rsp), %ymm0
vmovupd 3072(%rsp), %ymm12
vfmadd231pd %ymm6, %ymm12, %ymm0
vmovupd %ymm0, 352(%rsp)
vmovupd 1440(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm15, %ymm0
vmovupd %ymm0, 1440(%rsp)
vmovupd 576(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm15, %ymm0
vmovupd %ymm0, 576(%rsp)
vmovupd 1696(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm15, %ymm0
vmovupd %ymm0, 1696(%rsp)
vmovupd 736(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm15, %ymm0
vmovupd %ymm0, 736(%rsp)
vmovupd 896(%rsp), %ymm0
vmovupd 3808(%rsp), %ymm1
vmovupd 3680(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 896(%rsp)
vmovupd 160(%rsp), %ymm0
vmovupd 3456(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 160(%rsp)
vmovupd 1152(%rsp), %ymm0
vmovupd 3232(%rsp), %ymm7
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1152(%rsp)
vmovupd 384(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 384(%rsp)
vmovupd 1472(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm9, %ymm0
vmovupd %ymm0, 1472(%rsp)
vmovupd 608(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm9, %ymm0
vmovupd %ymm0, 608(%rsp)
vmovupd 1728(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm9, %ymm0
vmovupd %ymm0, 1728(%rsp)
vmovupd -128(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm9, %ymm0
vmovupd %ymm0, -128(%rsp)
vmovupd 928(%rsp), %ymm0
vmovupd 3840(%rsp), %ymm1
vmovupd 3712(%rsp), %ymm2
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 928(%rsp)
vmovupd 192(%rsp), %ymm0
vmovupd 3488(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 192(%rsp)
vmovupd 1184(%rsp), %ymm0
vmovupd 3264(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1184(%rsp)
vmovupd 416(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 416(%rsp)
vmovupd 1504(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm14, %ymm0
vmovupd %ymm0, 1504(%rsp)
vmovupd 640(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm14, %ymm0
vmovupd %ymm0, 640(%rsp)
vmovupd 1760(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm14, %ymm0
vmovupd %ymm0, 1760(%rsp)
vmovupd -96(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm14, %ymm0
vmovupd %ymm0, -96(%rsp)
vpaddq 1984(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vmovupd 960(%rsp), %ymm0
vmovupd 3872(%rsp), %ymm1
vmovupd 3744(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 960(%rsp)
vmovupd 224(%rsp), %ymm0
vmovupd 3520(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 224(%rsp)
vmovupd 1216(%rsp), %ymm0
vmovupd 3296(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1216(%rsp)
vmovupd 448(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 448(%rsp)
vmovupd 1536(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm11, %ymm0
vmovupd %ymm0, 1536(%rsp)
vmovupd 672(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm11, %ymm0
vmovupd %ymm0, 672(%rsp)
vmovupd 1792(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm11, %ymm0
vmovupd %ymm0, 1792(%rsp)
vmovupd -64(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm11, %ymm0
vmovupd %ymm0, -64(%rsp)
vmovupd 992(%rsp), %ymm0
vmovupd 3904(%rsp), %ymm1
vmovupd 3776(%rsp), %ymm3
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 992(%rsp)
vmovupd 256(%rsp), %ymm0
vmovupd 3584(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 256(%rsp)
vmovupd 1248(%rsp), %ymm0
vmovupd 3328(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 1248(%rsp)
vmovupd 1312(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 1312(%rsp)
vmovupd 1568(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm10, %ymm0
vmovupd %ymm0, 1568(%rsp)
vmovupd 704(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm10, %ymm0
vmovupd %ymm0, 704(%rsp)
vmovupd 1824(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm10, %ymm0
vmovupd %ymm0, 1824(%rsp)
vmovupd -32(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm10, %ymm0
vmovupd %ymm0, -32(%rsp)
vpaddq 1952(%rsp), %ymm8, %ymm8
addq $-4, %rcx
jne .LBB0_5
循环很大,显然没有正确矢量化:有很多完全无用的指令,内存加载似乎不连续(参见 vgatherqpd
)。 Numba 不会生成好的代码,因为底层 JIT (LLVM-Lite) 无法有效地矢量化代码。事实上,我发现类似的 C++ 代码在 simplified example 上被 Clang 13.0 严重矢量化(GCC 和 ICC 在更复杂的代码上也失败),而 hand-written SIMD 实现工作得更好。它看起来像是优化器的错误,或者至少是错过了优化。这就是 Numba 代码比最优代码慢得多的原因。话虽这么说,这个实现非常有效地使用了缓存,并且是正确的多线程。
我还发现 BLAS 代码在 Linux 上比在我的机器上 Windows 更快(默认包来自 PIP 和相同的 Numpy 版本 1.20.3)。因此,method2
和 method3
之间的差距更小,但后者仍然快得多。