使用 AVX2 实现的 GEMM 内核比 Zen 2 CPU 上的 AVX2/FMA 更快

GEMM kernel implemented using AVX2 is faster than AVX2/FMA on a Zen 2 CPU

我试过加速玩具 GEMM 的实施。我处理 32x32 双打块,我需要一个优化的 MM 内核。我可以访问 AVX2 和 FMA。

我在下面定义了两个代码(在 ASM 中,我为格式的粗糙表示歉意),一个是利用 AVX2 功能,另一个使用 FMA。

在不进行微基准测试的情况下,我想尝试理解(理论上)为什么 AVX2 实现比 FMA 版本快 1.11 倍。以及如何改进这两个版本。

下面的代码适用于 3000x3000 MM 的双打,内核是使用经典的、朴素的 MM 和互换的最深循环实现的。我正在使用 Ryzen 3700x/Zen 2 作为开发 CPU。

我没有尝试积极展开,担心 CPU 可能 运行 超出物理寄存器。

AVX2 32x32 MM 内核:

Block 82:
    imul r12, r15, 0xbb8
    mov rax, r11
    mov r13d, 0x0
    vmovupd ymm0, ymmword ptr [rdi+r12*8]
    vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
    vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
    vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
    vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
    vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
    vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
    vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
    lea r14, ptr [r12+0x4]
    nop dword ptr [rax+rax*1], eax
Block 83:
    vbroadcastsd ymm8, qword ptr [rcx+r13*8]
    inc r13
    vmulpd ymm10, ymm8, ymmword ptr [rax-0xa0]
    vmulpd ymm11, ymm8, ymmword ptr [rax-0x80]
    vmulpd ymm9, ymm8, ymmword ptr [rax-0xe0]
    vmulpd ymm12, ymm8, ymmword ptr [rax-0xc0]
    vaddpd ymm2, ymm10, ymm2    
    vmulpd ymm10, ymm8, ymmword ptr [rax-0x60]
    vaddpd ymm3, ymm11, ymm3    
    vmulpd ymm11, ymm8, ymmword ptr [rax-0x40]
    vaddpd ymm0, ymm9, ymm0   
    vaddpd ymm1, ymm12, ymm1
    vaddpd ymm4, ymm10, ymm4
    vmulpd ymm10, ymm8, ymmword ptr [rax-0x20]
    vmulpd ymm8, ymm8, ymmword ptr [rax]       
    vaddpd ymm5, ymm11, ymm5    
    add rax, 0x5dc0 
    vaddpd ymm6, ymm10, ymm6
    vaddpd ymm7, ymm8, ymm7 
    cmp r13, 0x20
    jnz 0x140004530 <Block 83>
Block 84:
    inc r15
    add rcx, 0x5dc0
    vmovupd ymmword ptr [rdi+r12*8], ymm0
    vmovupd ymmword ptr [rdi+r14*8], ymm1
    vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
    vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
    vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
    vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
    vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
    vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
    cmp r15, 0x20
    jnz 0x1400044d0 <Block 82>

AVX2/FMA 32x32 MM内核:

Block 80:
    imul r12, r15, 0xbb8
    mov rax, r11
    mov r13d, 0x0
    vmovupd ymm0, ymmword ptr [rdi+r12*8]
    vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
    vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
    vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
    vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
    vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
    vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
    vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
    lea r14, ptr [r12+0x4]
    nop dword ptr [rax+rax*1], eax
Block 81:
    vbroadcastsd ymm8, qword ptr [rcx+r13*8]
    inc r13
    vfmadd231pd ymm0, ymm8, ymmword ptr [rax-0xe0]
    vfmadd231pd ymm1, ymm8, ymmword ptr [rax-0xc0]
    vfmadd231pd ymm2, ymm8, ymmword ptr [rax-0xa0]
    vfmadd231pd ymm3, ymm8, ymmword ptr [rax-0x80]
    vfmadd231pd ymm4, ymm8, ymmword ptr [rax-0x60]
    vfmadd231pd ymm5, ymm8, ymmword ptr [rax-0x40]
    vfmadd231pd ymm6, ymm8, ymmword ptr [rax-0x20]
    vfmadd231pd ymm7, ymm8, ymmword ptr [rax]
    add rax, 0x5dc0 
    cmp r13, 0x20   
    jnz 0x140004450
Block 82:
    inc r15
    add rcx, 0x5dc0
    vmovupd ymmword ptr [rdi+r12*8], ymm0
    vmovupd ymmword ptr [rdi+r14*8], ymm1
    vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
    vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
    vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
    vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
    vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
    vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
    cmp r15, 0x20
    jnz 0x1400043f0 <Block 80>

Zen2 vaddpd 有 3 个周期延迟,vfma...pd 有 5 个周期延迟。 (https://uops.info/).

具有 8 个累加器的代码具有足够的 ILP,您期望每个时钟接近两个 FMA,大约每 5 个时钟 8 个(如果没有其他瓶颈),这比理论值的 10/5 略小最大值

vaddpdvmulpd 实际上 运行 在 Zen2 上的 不同 端口(与英特尔不同),端口 FP2/3 和 FP0/1,所以理论上可以维持2/clock vaddpd vmulpd。由于循环承载依赖的延迟更短,如果调度不让一个 dep 链落后,8 个累加器足以隐藏 vaddpd 延迟。 (但至少乘法不会从中窃取周期。)

Zen2 的前端有 5 个指令宽(如果有任何多 uop 指令,则为 6 uops),并且它可以将内存源指令解码为单个 uop。所以它很可能在每次乘法和非 FMA 版本中做 2/clock。

如果您可以展开 10 或 12,这可能会隐藏足够的 FMA 延迟并使其与非 FMA 版本相同,但功耗更低且更适合 SMT 编码 运行ning另一个逻辑核心。 (10 = 5 x 2 只是 勉强 就够了,这意味着任何调度缺陷都会在关键路径上的 dep 链上失去进展。请参阅 进行一些测试英特尔。)

(相比之下,英特尔 Skylake 运行s vaddpd/vmulpd 在相同端口上与 vfma...pd 具有相同的延迟,均具有 4c 延迟,0.5c 吞吐量。)

我没有仔细查看您的代码,但 10 个 YMM 向量可能是接触两对缓存行与接触 5 条总行之间的权衡,如果空间预取器试图完成一个缓存行,这可能会更糟对齐对。或者可能还好。 12个YMM向量就是三对,应该没问题。

根据矩阵大小,无序执行可能能够在外循环的单独迭代之间重叠内循环 dep 链,特别是如果循环退出条件可以更快地执行并解决错误预测(如果有一)FP工作仍在进行中。对于相同的工作,使用更少的总 uops 是一个优势,有利于 FMA。