使用 AVX2 实现的 GEMM 内核比 Zen 2 CPU 上的 AVX2/FMA 更快
GEMM kernel implemented using AVX2 is faster than AVX2/FMA on a Zen 2 CPU
我试过加速玩具 GEMM 的实施。我处理 32x32 双打块,我需要一个优化的 MM 内核。我可以访问 AVX2 和 FMA。
我在下面定义了两个代码(在 ASM 中,我为格式的粗糙表示歉意),一个是利用 AVX2 功能,另一个使用 FMA。
在不进行微基准测试的情况下,我想尝试理解(理论上)为什么 AVX2 实现比 FMA 版本快 1.11 倍。以及如何改进这两个版本。
下面的代码适用于 3000x3000 MM 的双打,内核是使用经典的、朴素的 MM 和互换的最深循环实现的。我正在使用 Ryzen 3700x/Zen 2 作为开发 CPU。
我没有尝试积极展开,担心 CPU 可能 运行 超出物理寄存器。
AVX2 32x32 MM 内核:
Block 82:
imul r12, r15, 0xbb8
mov rax, r11
mov r13d, 0x0
vmovupd ymm0, ymmword ptr [rdi+r12*8]
vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
lea r14, ptr [r12+0x4]
nop dword ptr [rax+rax*1], eax
Block 83:
vbroadcastsd ymm8, qword ptr [rcx+r13*8]
inc r13
vmulpd ymm10, ymm8, ymmword ptr [rax-0xa0]
vmulpd ymm11, ymm8, ymmword ptr [rax-0x80]
vmulpd ymm9, ymm8, ymmword ptr [rax-0xe0]
vmulpd ymm12, ymm8, ymmword ptr [rax-0xc0]
vaddpd ymm2, ymm10, ymm2
vmulpd ymm10, ymm8, ymmword ptr [rax-0x60]
vaddpd ymm3, ymm11, ymm3
vmulpd ymm11, ymm8, ymmword ptr [rax-0x40]
vaddpd ymm0, ymm9, ymm0
vaddpd ymm1, ymm12, ymm1
vaddpd ymm4, ymm10, ymm4
vmulpd ymm10, ymm8, ymmword ptr [rax-0x20]
vmulpd ymm8, ymm8, ymmword ptr [rax]
vaddpd ymm5, ymm11, ymm5
add rax, 0x5dc0
vaddpd ymm6, ymm10, ymm6
vaddpd ymm7, ymm8, ymm7
cmp r13, 0x20
jnz 0x140004530 <Block 83>
Block 84:
inc r15
add rcx, 0x5dc0
vmovupd ymmword ptr [rdi+r12*8], ymm0
vmovupd ymmword ptr [rdi+r14*8], ymm1
vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
cmp r15, 0x20
jnz 0x1400044d0 <Block 82>
AVX2/FMA 32x32 MM内核:
Block 80:
imul r12, r15, 0xbb8
mov rax, r11
mov r13d, 0x0
vmovupd ymm0, ymmword ptr [rdi+r12*8]
vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
lea r14, ptr [r12+0x4]
nop dword ptr [rax+rax*1], eax
Block 81:
vbroadcastsd ymm8, qword ptr [rcx+r13*8]
inc r13
vfmadd231pd ymm0, ymm8, ymmword ptr [rax-0xe0]
vfmadd231pd ymm1, ymm8, ymmword ptr [rax-0xc0]
vfmadd231pd ymm2, ymm8, ymmword ptr [rax-0xa0]
vfmadd231pd ymm3, ymm8, ymmword ptr [rax-0x80]
vfmadd231pd ymm4, ymm8, ymmword ptr [rax-0x60]
vfmadd231pd ymm5, ymm8, ymmword ptr [rax-0x40]
vfmadd231pd ymm6, ymm8, ymmword ptr [rax-0x20]
vfmadd231pd ymm7, ymm8, ymmword ptr [rax]
add rax, 0x5dc0
cmp r13, 0x20
jnz 0x140004450
Block 82:
inc r15
add rcx, 0x5dc0
vmovupd ymmword ptr [rdi+r12*8], ymm0
vmovupd ymmword ptr [rdi+r14*8], ymm1
vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
cmp r15, 0x20
jnz 0x1400043f0 <Block 80>
Zen2 vaddpd
有 3 个周期延迟,vfma...pd
有 5 个周期延迟。 (https://uops.info/).
具有 8 个累加器的代码具有足够的 ILP,您期望每个时钟接近两个 FMA,大约每 5 个时钟 8 个(如果没有其他瓶颈),这比理论值的 10/5 略小最大值
vaddpd
和 vmulpd
实际上 运行 在 Zen2 上的 不同 端口(与英特尔不同),端口 FP2/3 和 FP0/1,所以理论上可以维持2/clock vaddpd
和 vmulpd
。由于循环承载依赖的延迟更短,如果调度不让一个 dep 链落后,8 个累加器足以隐藏 vaddpd
延迟。 (但至少乘法不会从中窃取周期。)
Zen2 的前端有 5 个指令宽(如果有任何多 uop 指令,则为 6 uops),并且它可以将内存源指令解码为单个 uop。所以它很可能在每次乘法和非 FMA 版本中做 2/clock。
如果您可以展开 10 或 12,这可能会隐藏足够的 FMA 延迟并使其与非 FMA 版本相同,但功耗更低且更适合 SMT 编码 运行ning另一个逻辑核心。 (10 = 5 x 2 只是 勉强 就够了,这意味着任何调度缺陷都会在关键路径上的 dep 链上失去进展。请参阅 进行一些测试英特尔。)
(相比之下,英特尔 Skylake 运行s vaddpd/vmulpd 在相同端口上与 vfma...pd 具有相同的延迟,均具有 4c 延迟,0.5c 吞吐量。)
我没有仔细查看您的代码,但 10 个 YMM 向量可能是接触两对缓存行与接触 5 条总行之间的权衡,如果空间预取器试图完成一个缓存行,这可能会更糟对齐对。或者可能还好。 12个YMM向量就是三对,应该没问题。
根据矩阵大小,无序执行可能能够在外循环的单独迭代之间重叠内循环 dep 链,特别是如果循环退出条件可以更快地执行并解决错误预测(如果有一)FP工作仍在进行中。对于相同的工作,使用更少的总 uops 是一个优势,有利于 FMA。
我试过加速玩具 GEMM 的实施。我处理 32x32 双打块,我需要一个优化的 MM 内核。我可以访问 AVX2 和 FMA。
我在下面定义了两个代码(在 ASM 中,我为格式的粗糙表示歉意),一个是利用 AVX2 功能,另一个使用 FMA。
在不进行微基准测试的情况下,我想尝试理解(理论上)为什么 AVX2 实现比 FMA 版本快 1.11 倍。以及如何改进这两个版本。
下面的代码适用于 3000x3000 MM 的双打,内核是使用经典的、朴素的 MM 和互换的最深循环实现的。我正在使用 Ryzen 3700x/Zen 2 作为开发 CPU。
我没有尝试积极展开,担心 CPU 可能 运行 超出物理寄存器。
AVX2 32x32 MM 内核:
Block 82:
imul r12, r15, 0xbb8
mov rax, r11
mov r13d, 0x0
vmovupd ymm0, ymmword ptr [rdi+r12*8]
vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
lea r14, ptr [r12+0x4]
nop dword ptr [rax+rax*1], eax
Block 83:
vbroadcastsd ymm8, qword ptr [rcx+r13*8]
inc r13
vmulpd ymm10, ymm8, ymmword ptr [rax-0xa0]
vmulpd ymm11, ymm8, ymmword ptr [rax-0x80]
vmulpd ymm9, ymm8, ymmword ptr [rax-0xe0]
vmulpd ymm12, ymm8, ymmword ptr [rax-0xc0]
vaddpd ymm2, ymm10, ymm2
vmulpd ymm10, ymm8, ymmword ptr [rax-0x60]
vaddpd ymm3, ymm11, ymm3
vmulpd ymm11, ymm8, ymmword ptr [rax-0x40]
vaddpd ymm0, ymm9, ymm0
vaddpd ymm1, ymm12, ymm1
vaddpd ymm4, ymm10, ymm4
vmulpd ymm10, ymm8, ymmword ptr [rax-0x20]
vmulpd ymm8, ymm8, ymmword ptr [rax]
vaddpd ymm5, ymm11, ymm5
add rax, 0x5dc0
vaddpd ymm6, ymm10, ymm6
vaddpd ymm7, ymm8, ymm7
cmp r13, 0x20
jnz 0x140004530 <Block 83>
Block 84:
inc r15
add rcx, 0x5dc0
vmovupd ymmword ptr [rdi+r12*8], ymm0
vmovupd ymmword ptr [rdi+r14*8], ymm1
vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
cmp r15, 0x20
jnz 0x1400044d0 <Block 82>
AVX2/FMA 32x32 MM内核:
Block 80:
imul r12, r15, 0xbb8
mov rax, r11
mov r13d, 0x0
vmovupd ymm0, ymmword ptr [rdi+r12*8]
vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
lea r14, ptr [r12+0x4]
nop dword ptr [rax+rax*1], eax
Block 81:
vbroadcastsd ymm8, qword ptr [rcx+r13*8]
inc r13
vfmadd231pd ymm0, ymm8, ymmword ptr [rax-0xe0]
vfmadd231pd ymm1, ymm8, ymmword ptr [rax-0xc0]
vfmadd231pd ymm2, ymm8, ymmword ptr [rax-0xa0]
vfmadd231pd ymm3, ymm8, ymmword ptr [rax-0x80]
vfmadd231pd ymm4, ymm8, ymmword ptr [rax-0x60]
vfmadd231pd ymm5, ymm8, ymmword ptr [rax-0x40]
vfmadd231pd ymm6, ymm8, ymmword ptr [rax-0x20]
vfmadd231pd ymm7, ymm8, ymmword ptr [rax]
add rax, 0x5dc0
cmp r13, 0x20
jnz 0x140004450
Block 82:
inc r15
add rcx, 0x5dc0
vmovupd ymmword ptr [rdi+r12*8], ymm0
vmovupd ymmword ptr [rdi+r14*8], ymm1
vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
cmp r15, 0x20
jnz 0x1400043f0 <Block 80>
Zen2 vaddpd
有 3 个周期延迟,vfma...pd
有 5 个周期延迟。 (https://uops.info/).
具有 8 个累加器的代码具有足够的 ILP,您期望每个时钟接近两个 FMA,大约每 5 个时钟 8 个(如果没有其他瓶颈),这比理论值的 10/5 略小最大值
vaddpd
和 vmulpd
实际上 运行 在 Zen2 上的 不同 端口(与英特尔不同),端口 FP2/3 和 FP0/1,所以理论上可以维持2/clock vaddpd
和 vmulpd
。由于循环承载依赖的延迟更短,如果调度不让一个 dep 链落后,8 个累加器足以隐藏 vaddpd
延迟。 (但至少乘法不会从中窃取周期。)
Zen2 的前端有 5 个指令宽(如果有任何多 uop 指令,则为 6 uops),并且它可以将内存源指令解码为单个 uop。所以它很可能在每次乘法和非 FMA 版本中做 2/clock。
如果您可以展开 10 或 12,这可能会隐藏足够的 FMA 延迟并使其与非 FMA 版本相同,但功耗更低且更适合 SMT 编码 运行ning另一个逻辑核心。 (10 = 5 x 2 只是 勉强 就够了,这意味着任何调度缺陷都会在关键路径上的 dep 链上失去进展。请参阅
(相比之下,英特尔 Skylake 运行s vaddpd/vmulpd 在相同端口上与 vfma...pd 具有相同的延迟,均具有 4c 延迟,0.5c 吞吐量。)
我没有仔细查看您的代码,但 10 个 YMM 向量可能是接触两对缓存行与接触 5 条总行之间的权衡,如果空间预取器试图完成一个缓存行,这可能会更糟对齐对。或者可能还好。 12个YMM向量就是三对,应该没问题。
根据矩阵大小,无序执行可能能够在外循环的单独迭代之间重叠内循环 dep 链,特别是如果循环退出条件可以更快地执行并解决错误预测(如果有一)FP工作仍在进行中。对于相同的工作,使用更少的总 uops 是一个优势,有利于 FMA。