AVX 内在澄清,4x4 矩阵乘法奇数

AVX Intrinsic Clarification, 4x4 Matrix Multiplication Oddities

我在纸上画出了这个算法的长形式,在纸上它应该可以正常工作。我是 运行 对寄存器转换 (256/128/256) 很敏感,还是我真的在某处搞砸了算法结构?

为方便起见,我已将普通代码和 AVX 代码放在 Godbolt 查看器上,以便您随意查看生成的程序集。

标准代码 https://godbolt.org/g/v47RKH

我的 AVX 尝试 1: https://godbolt.org/g/oH1DpO

我的 AVX 尝试 2: https://godbolt.org/g/QFtdKr(减少了 5 个循环并减少了铸造需求,更易于阅读)

奇怪的是,SSE 代码使用的是标量运算,这让我感到困惑,因为这肯定可以通过水平广播、乘积和加法来加速。我想做的是将这个概念提升一个层次。

RHS 永远不需要改变,但本质上如果 LHS 是 {a, b, ..., p}, LHS 是 {1, 2, ..., 16},那么我们只需要 2 个寄存器来保存 RHS 的两半,然后 2 个寄存器来保存 LHS 的给定行,形式为 {a, a, a, a , b, b, b, b} 和 {c, c, c, c, d, d, d, d}。这是通过 2 次广播和 256/128/256 投射实现的。

我们得到了

的中间结果

{a*1, a*2, a*3, a*4, b*5, b*6, b*7, b*8} => row[0]

{c*9, c*10, c*11, c*12, d*13, d*14, d*15, d*16} => row[1]

展开一次 w.r.t LHS 所以我们生成

{e*1, ... f*8}, {g*9, ... h*16} => row[2], row[3]

接下来将 r0、r1 和 r2、r3 加在一起(保持 r0 和 r2 作为当前中间体)

最后,提取row[0]的高半部分到resHalf的低半部分,将row[2]的低半部分插入resHalf的高半部分,将row[2]的高半部分插入resHalf row[0] 的高一半,然后将 row[0] 添加到 resHalf.

根据所有权利,在迭代 i = 0

结束时,resHalf[0] 应该等于以下值

{a*1 + b*2 + c*3 + d*4, a*5 + b*6 + c*7 + d*8,

a*9 + b*10 + c*11 + d*12, a*13 + b*14 + c*15 + d*16,

e*1 + ... + h*4, e*5 + ... + h*8,

e*9 + ... + h*12, e*13 + ... + h*16}

然而,我的算法产生的结果如下:

2x {a*1 + c*3, a*5 + c*7, a*9 + c*11, a*13 + c*15},

2x {e*1 + g*3, e*5 + g*7, e*9 + g*11, e*13 + g*15}

更可怕的是,如果我在三元条件中交换 rhsHolders[0/1],它根本不会改变结果。就好像编译器忽略了交换和添加之一。 Clang 4 和 GCC 7 都这样做,那么我哪里搞砸了?

编辑:输出应该是 4 行 {10, 26, 42, 58},但我得到 {4, 12, 20, 28}

这几乎是我昨天在 SO 上的回答的复制粘贴:)

试试这个

void MatMul(const float* __restrict lhs , const float* __restrict rhs , float* __restrict out ) 
{
  lhs = reinterpret_cast<float*>(__builtin_assume_aligned (lhs, 32));
  rhs = reinterpret_cast<float*>(__builtin_assume_aligned (rhs, 32));
  out = reinterpret_cast<float*>(__builtin_assume_aligned (out, 32));
  for(int i = 0; i < 4; i++){
    for(int j = 0; j < 4; j++){
      for (int k = 0; k < 4; k++){
        out[i*4 + j] += lhs[i*4 + k]*rhs[k*4 + i];
      }
    }     
  }     
}

使用以下之一编译(衡量哪一个对您来说最快)

-O3 -mavx
-O3 -mavx2
-O3 -mavx2 -mfma
-O3 -mavx2 -mfma -ffast-math

这在 GCC 下有效(我的意思是矢量化),cLANG 出于某种原因未能这样做。 GCC 也将展开循环。

The SSE code oddly enough is using scalar operations, which boggles my mind since that can definitely be accelerated with horizontal broadcasts, muls, and adds.

你是指编译器生成的汇编代码吗? clang4.0 和 gcc7.1 输出中 MatMul() 中的所有 AVX 指令都在 ymm 向量上运行。除了 clang 愚蠢的广播加载:它执行标量加载,然后执行单独的 AVX2 广播指令,这非常糟糕,因为英特尔 CPU 将广播加载作为单 uop ALU 指令处理。加载端口本身可以进行广播。但是如果源是一个寄存器,它需要一个 ALU uop 用于 shuffle 端口。

    vmovss  xmm5, dword ptr [rdi + 24] # xmm5 = mem[0],zero,zero,zero
    vbroadcastss    xmm5, xmm5

clang 的实际输出(上图)与 gcc 使用的 AVX1 vbroadcastss xmm5, [rdi + 24] 相比真的很愚蠢。

main() 中,clang 确实发出标量运算

因为你的输入矩阵都是编译时常量,唯一的谜团是为什么它没有优化到 cout << "a long string with the numbers already formatted\n";,或者至少优化掉所有数学并准备好 double 结果以供打印。 (是的,它们在打印循环中使用 vcvtss2sdfloat 转换为 double。)

它通过一些内在的洗牌和数学优化,在编译时进行。我猜 clang 在洗牌的某个地方迷路了,但仍然发出了一些数学运算。它们是标量的事实可能表明它在编译时没有做太多工作,但它没有重新排序以对其进行矢量化。

请注意,某些常量未出现在源代码中,并且它们在内存中未按升序排列。

...
.LCPI1_5:
        .long   1092616192              # float 10
.LCPI1_6:
        .long   1101004800              # float 20
.LCPI1_7:
        .long   1098907648              # float 16
...

clang 将浮点值放在位模式的整数表示之后的注释中真是太好了。


or did I actually mess up the algorithm structure somewhere?

好吧,这部分实现看起来完全是假的。您从 rows[j] 初始化 lowerHalf,但随后在下一个语句中覆盖该值。

__m128 lowerHalf = _mm256_castps256_ps128(rows[j]);
    lowerHalf = _mm_broadcast_ss(&lhs[offset+2*j]);

然后你用 rows[j] undefined 的上 128b 通道做一个 256b 乘法。

    rows[j] = _mm256_castps128_ps256(lowerHalf);
    rows[j] = _mm256_mul_ps(rows[j], (chooser) ? rhsHolders[0] : rhsHolders[1]);

在 gcc 和 clang 的 asm 中,上面的通道全为零(因为它们明显选择使用标量 -> xmm 广播最后写入的 ymm 寄存器,它隐式地零扩展到最大向量宽度)。请注意 _mm256_castps128_ps256 不保证零扩展。除非 __m128 本身是来自 256b 或更宽向量的 extract/cast 的结果,否则很有可能,但它是未定义的。请参阅 How to clear the upper 128 bits of __m256 value? 以了解需要在矢量中将上车道置零的情况。

无论如何,这意味着您将从 128b 向量乘法 (vmulps xmm, xmm, xmm) 中得到相同的结果:在这些指令之后,上面的 4 个元素将全部为零(或 NaN)

    vbroadcastss    xmm0, DWORD PTR [rdi+40]
    vmulps  ymm0, ymm2, ymm0

这种 asm 输出(来自 gcc7.1)极不可能是正确的 matmul 实现的一部分。

我没有仔细查看您在源代码中到底想做什么,但我认为它不完全是这个。


And what's scarier still is if I swap rhsHolders[0/1] in the ternary conditional, it doesn't change the results at all. It's as though the compiler is ignoring one of the swaps and adds.

如果更改源代码中的某些内容未在 asm 输出中产生您期望的更改,则表明您可能弄错了源代码,并且正在优化某些内容。有时我 copy/paste 一个内部函数而忘记在新行中更改输入变量,所以我的函数忽略了它的一些计算结果并使用另一个两次。