如何将此汇编代码转换为内部代码?

How to convert this assembly code to intrinsic code?

下面好像是intrinsic,不过我对intrinsic functions不熟悉。请帮我转换真正的代码。特别是, testFunc() 对我来说更加模糊。 我猜它也是两个浮点向量的点积,但是,标签 Lrep 和 Lexit 让我感到困惑。 请帮我想清楚。 内在函数可用于移动处理器?

void testFunc(int M, int N, int K, float* A, float* B, float* C)
{
    float *a;
    float *b = new float[K*N];
    float *pointb = B;
    float *bb;
    float *answer = C;
    float c[8];

    for (int j = 0, k; j < K; j++) {
        bb = b + j;
        for (k = N / 8; k > 0; k--) {
            *bb = *pointb++; bb += K;
            *bb = *pointb++; bb += K;
            *bb = *pointb++; bb += K;
            *bb = *pointb++; bb += K;
            *bb = *pointb++; bb += K;
            *bb = *pointb++; bb += K;
            *bb = *pointb++; bb += K;
            *bb = *pointb++; bb += K;
        }
        for (k = N / 8 * 8; k < N; k++) {
            *bb = *pointb++; bb += K;
        }
    }

    int K8 = K / 8 * 8;

    for (int i = 0; i < M; i++) for (int k = 0; k < N; k++) {
        a = A + i * K;
        bb = b + k * K;
        __asm {
            mov             esi, K8;
            sub             esi, 8;
            shl             esi, 2;
            xor             edi, edi;
            mov             edx, a;
            mov             ebx, bb;
            vxorps          ymm3, ymm3, ymm3;
        Lrep:
            cmp             edi, esi;
            jg              Lexit;
            vmovups         ymm0, ymmword ptr[edx + edi];
            vfmadd231ps     ymm3, ymm0, ymmword ptr[ebx + edi];
            add             edi, 32;
            jmp             Lrep;
        Lexit:
            vmovups         ymmword ptr[c], ymm3;
        }

        for (int j = K8; j < K; ) {
            *c += *(a + j) * *(bb + j); j++;
        }

        *answer = (c[0] + c[1] + c[2] + c[3] + c[4] + c[5] + c[6] + c[7]);
        answer++;
    }
}

pA = A;
for (k = 0; k < K; k++) {
    pC = C;
    for (i = 0; i < M; i++) {
        pA = A + i * K + k;
        pB = B + k * N;
        for (j = N / 32; j > 0; j--) {
            _asm {
                mov             eax, pC;
                mov             ebx, pA;
                mov             ecx, pB;
                vmovups         ymm0, ymmword ptr[eax];
                vmovss          xmm1, dword ptr[ebx];
                vbroadcastss    ymm4, xmm1;
                vmovups         ymm2, ymmword ptr[ecx];
                vfmadd231ps     ymm0, ymm4, ymm2;
                vmovups         ymmword ptr[eax], ymm0;
            }
            pC += 8; pB += 8;
            _asm {
                mov             eax, pC;
                mov             ebx, pA;
                mov             ecx, pB;
                vmovups         ymm0, ymmword ptr[eax];
                vmovss          xmm1, dword ptr[ebx];
                vbroadcastss    ymm4, xmm1;
                vmovups         ymm2, ymmword ptr[ecx];
                vfmadd231ps     ymm0, ymm4, ymm2;
                vmovups         ymmword ptr[eax], ymm0;
            }
            pC += 8; pB += 8;
            _asm {
                mov             eax, pC;
                mov             ebx, pA;
                mov             ecx, pB;
                vmovups         ymm0, ymmword ptr[eax];
                vmovss          xmm1, dword ptr[ebx];
                vbroadcastss    ymm4, xmm1;
                vmovups         ymm2, ymmword ptr[ecx];
                vfmadd231ps     ymm0, ymm4, ymm2;
                vmovups         ymmword ptr[eax], ymm0;
            }
            pC += 8; pB += 8;
            _asm {
                mov             eax, pC;
                mov             ebx, pA;
                mov             ecx, pB;
                vmovups         ymm0, ymmword ptr[eax];
                vmovss          xmm1, dword ptr[ebx];
                vbroadcastss    ymm4, xmm1;
                vmovups         ymm2, ymmword ptr[ecx];
                vfmadd231ps     ymm0, ymm4, ymm2;
                vmovups         ymmword ptr[eax], ymm0;
            }
            pC += 8; pB += 8;
        }
        for (j = N / 32 * 32; j < N; j++) {
            *pC += *pA * *pB;
            pC += 1; pB += 1;
        }
    }
}

我会做前几行来让你开始,但实际上,如果你不能阅读汇编,你需要参考英特尔 CPU 手册才能破译它。

mov             esi, K8;
sub             esi, 8;
shl             esi, 2;
xor             edi, edi;
mov             edx, a;
mov             ebx, bb;
mov             esi, K8
  1. 将K8的内容复制到esi中
  2. 用 easi 中的值减去 8
  3. 将esi左移2位并将结果复制到esi
  4. 针对 edi 对 edi 应用异或运算(这将为 0,如果您了解二进制和寄存器的工作原理,原因就很清楚)
  5. 将 a 的内容复制到 edx
  6. 将bb的内容复制到ebx中
  7. 将K8的内容复制到esi中

从这里开始,您需要熟悉与您的问题相关的二进制和基本 cpu 体系结构以及汇编语言操作数,具体取决于您的知识所在。一旦你可以阅读每一行,然后你就可以破译块,最后是程序。

2 个向量加载(从 2 个数组中的相同位置)将 FMA 送入向量累加器对我来说闻起来像 dot-product。

我没有查看 asm 参考手册以了解目标操作数是总和而不是被乘数的 1,但这是有意义的方式。

triple-nested 循环看起来像矩阵乘法。它广播 1 个输入,同时从另一个进行矢量加载以馈送 FMA,因此它可能正在为输出行生成结果的 SIMD 矢量。

为此使用 MSVC 内联 asm 语法非常糟糕;它只能通过内存操作数接受输入,因此它强制在每个 asm 块之间重新加载 + 存储。如果要展开,请使用一个大的 asm 语句并在寻址模式中使用位移。


IDK 为什么 dot-produce 循环编写效率低下(循环内有条件和无条件分支),并且没有使用多个累加器展开。几乎违背了 hand-coding 在 asm 中的目的。有关如何使用多个累加器隐藏 FMA 延迟的信息,请参阅 。或者在展开+矢量化纯 C 循环时让 clang 为您完成。

我也不知道为什么它没有 horizontal-sum 结果,而是用 vmovups [c], ymm3 将它存储到内存中。似乎毫无意义。我想调用者必须从内存中重新加载并求和,或者您可以将函数声明为返回 __m256 向量并忽略存储。


无论如何,您显然可以用标量 C 代码编写 dot-product,也许使用 math.h 中的 fma(a[i], b[i], sum) 来复制 asm 不舍入临时结果的行为。

或者用 sum = _mm256_fmadd_ps(_mm256_loadu_ps(a[i]), _mm256_loadu_ps(b[i]), sum); 之类的内在函数复制手动矢量化。 (参见 Intel's intrinsics guide)。

在内部函数中,这段代码重复了 4 次。

{
// vmovups         ymm0, ymmword ptr[eax];
__m256 tempC = _mm256_loadu_ps((float*)pC);

// vmovss          xmm1, dword ptr[ebx];
// vbroadcastss    ymm4, xmm1;
__m256 tempA = _mm256_set1_ps(*pA);

// vmovups         ymm2, ymmword ptr[ecx];
__m256 tempB = _mm256_loadu_ps((float*)pB);

// vfmadd231ps     ymm0, ymm4, ymm2;
__m256 result = _mm256_fmadd_ps(tempA, tempB, tempC);

// vmovups         ymmword ptr[eax], ymm0;
_mm256_storeu_ps(pC, result);
}

pC += 8; pB += 8;

不过,不断地从 pA 广播相同的值似乎有点多余。