在 Raspberry Pi 4 上使用 ARM Neon Intrinsics 加速矩阵向量乘法

Accelerating matrix vector multiplication with ARM Neon Intrinsics on Raspberry Pi 4

我需要优化矩阵向量乘法。数据如下所示:

此例程还必须满足一些非功能要求:

这是我的(经过简化,为了便于阅读,我假设输入被完全屏蔽)代码,

// input_height = 90000
// input_width = 81

for (uint32_t y = 0; y < input_height; y += 4) {
    float32x4_t sum0 = vmovq_n_f32(0);
    float32x4_t sum1 = vmovq_n_f32(0);
    float32x4_t sum2 = vmovq_n_f32(0);
    float32x4_t sum3 = vmovq_n_f32(0);

    for (uint32_t x = 0; x < input_width; x += 16) {
        float32x4x4_t A = load_matrix_transpose(kernel + x);

        float32x4x4_t B0 = load_matrix_transpose(input + y * input_width + x);
        float32x4x4_t B1 = load_matrix_transpose(input + (y + 1) * input_width + x);
        float32x4x4_t B2 = load_matrix_transpose(input + (y + 2) * input_width + x);
        float32x4x4_t B3 = load_matrix_transpose(input + (y + 3) * input_width + x);

        matrix_element_wise_multiplication(A, B0, sum0);
        matrix_element_wise_multiplication(A, B1, sum1);
        matrix_element_wise_multiplication(A, B2, sum2);
        matrix_element_wise_multiplication(A, B3, sum3);
    }

    output[y] = vaddvq_f32(sum0);
    output[y + 1] = vaddvq_f32(sum1);
    output[y + 2] = vaddvq_f32(sum2);
    output[y + 3] = vaddvq_f32(sum3);
}

其中load_matrix_transposematrix_element_wise_multiplication是以下函数:

inline float32x4x4_t load_matrix_transpose(float *a) {
    float32x4x4_t ret;

    ret.val[0] = simd_load(a);

    ret.val[1] = simd_load(a + 4);

    ret.val[2] = simd_load(a + 8);

    ret.val[3] = simd_load(a + 12);

    return ret;
}

inline void simd_matrix_element_wise_multiplication(float32x4x4_t & A, float32x4x4_t & B, float32x4x4_t & C) {
    C = vmlaq_f32(C, A.val[0], B.val[0]);
    C = vmlaq_f32(C, A.val[1], B.val[1]);
    C = vmlaq_f32(C, A.val[2], B.val[2]);
    C = vmlaq_f32(C, A.val[3], B.val[3]);
}

在我的 Rasperry Pi 4(ARMv8、8GB RAM、4 核)上,代码采用优化级别 -O3,大约 60ms

长 运行(多次循环),Neon 寄存器版本的速度正好是普通代码的两倍。

我的问题是,有没有进一步优化代码的方法?我尝试了很多东西,但无法对正常代码进行任何改进。

就优化而言,数据局部性是最高优先级,您应该了解寄存器容量,因为寄存器是迄今为止最快和最稀缺的资源。

aarch64: 32x128bit neon 寄存器(512 字节)
aarch32: 16x128bit neon 寄存器(256 字节)

一个 81x90000 的矩阵在转置时需要保存 90000 个中间值来进行乘法运算,并且由于 360000 字节不适合 512 字节的寄存器组,因此将会有大量的内存交换,这会转化为巨大的性能影响.
另一方面,向量的 4*81 字节正好适合 512 字节。

void matVecMult81x90000(float *pDst, float *pMat, float *pVec)
{
    register float32x4_t vec0_3, vec4_7, vec8_11, vec12_15, vec16_19, vec20_23, vec24_27, vec28_31, vec32_35, vec36_39, vec40_43, vec44_47, vec48_51, vec52_55, vec56_59, vec60_63, vec64_67, vec68_71, vec72_75, vec76_79, vec80;
    register float32x4_t mat0, mat1, mat2, mat3, mat4, rslt;
    register float32x2_t drslt;
    register uint32_t nRows = 90000;

    vec80 = vdupq_n_f32(0.0f);
    mat4 =vdupq_n_f32(0.0f);
    vec0_3 = vld1q_f32(pVec); pVec += 4;
    vec4_7 = vld1q_f32(pVec); pVec += 4;
    vec8_11 = vld1q_f32(pVec); pVec += 4;
    vec12_15 = vld1q_f32(pVec); pVec += 4;
    vec16_19 = vld1q_f32(pVec); pVec += 4;
    vec20_23 = vld1q_f32(pVec); pVec += 4;
    vec24_27 = vld1q_f32(pVec); pVec += 4;
    vec28_31 = vld1q_f32(pVec); pVec += 4;
    vec32_35 = vld1q_f32(pVec); pVec += 4;
    vec36_39 = vld1q_f32(pVec); pVec += 4;
    vec40_43 = vld1q_f32(pVec); pVec += 4;
    vec44_47 = vld1q_f32(pVec); pVec += 4;
    vec48_51 = vld1q_f32(pVec); pVec += 4;
    vec52_55 = vld1q_f32(pVec); pVec += 4;
    vec56_59 = vld1q_f32(pVec); pVec += 4;
    vec60_63 = vld1q_f32(pVec); pVec += 4;
    vec64_67 = vld1q_f32(pVec); pVec += 4;
    vec68_71 = vld1q_f32(pVec); pVec += 4;
    vec72_75 = vld1q_f32(pVec); pVec += 4;
    vec76_79 = vld1q_f32(pVec); pVec += 4;
    vld1q_lane_f32(pVec, vec80, 0);

    do {
        mat0 = vld1q_f32(pMat); pMat += 4;
        mat1 = vld1q_f32(pMat); pMat += 4;
        mat2 = vld1q_f32(pMat); pMat += 4;
        mat3 = vld1q_f32(pMat); pMat += 4;
        rslt = vmulq_f32(mat0, vec0_3);
        rslt += vmulq_f32(mat1, vec4_7);
        rslt += vmulq_f32(mat2, vec8_11);
        rslt += vmulq_f32(mat3, vec12_15);

        mat0 = vld1q_f32(pMat); pMat += 4;
        mat1 = vld1q_f32(pMat); pMat += 4;
        mat2 = vld1q_f32(pMat); pMat += 4;
        mat3 = vld1q_f32(pMat); pMat += 4;
        rslt += vmulq_f32(mat0, vec16_19);
        rslt += vmulq_f32(mat1, vec20_23);
        rslt += vmulq_f32(mat2, vec24_27);
        rslt += vmulq_f32(mat3, vec28_31);

        mat0 = vld1q_f32(pMat); pMat += 4;
        mat1 = vld1q_f32(pMat); pMat += 4;
        mat2 = vld1q_f32(pMat); pMat += 4;
        mat3 = vld1q_f32(pMat); pMat += 4;
        rslt += vmulq_f32(mat0, vec32_35);
        rslt += vmulq_f32(mat1, vec36_39);
        rslt += vmulq_f32(mat2, vec40_43);
        rslt += vmulq_f32(mat3, vec44_47);

        mat0 = vld1q_f32(pMat); pMat += 4;
        mat1 = vld1q_f32(pMat); pMat += 4;
        mat2 = vld1q_f32(pMat); pMat += 4;
        mat3 = vld1q_f32(pMat); pMat += 4;
        rslt += vmulq_f32(mat0, vec48_51);
        rslt += vmulq_f32(mat1, vec52_55);
        rslt += vmulq_f32(mat2, vec56_59);
        rslt += vmulq_f32(mat3, vec60_63);

        mat0 = vld1q_f32(pMat); pMat += 4;
        mat1 = vld1q_f32(pMat); pMat += 4;
        mat2 = vld1q_f32(pMat); pMat += 4;
        mat3 = vld1q_f32(pMat); pMat += 4;
        vld1q_lane_f32(pMat, mat4, 0); pMat += 1;
        rslt += vmulq_f32(mat0, vec64_67);
        rslt += vmulq_f32(mat1, vec68_71);
        rslt += vmulq_f32(mat2, vec72_75);
        rslt += vmulq_f32(mat3, vec76_79);
        rslt += vmulq_f32(mat4, vec80);

        *pDst++ = vaddvq_f32(rslt);
    } while (--nRows);
}

不幸的是,编译器并不能很好地配合。 (GCC 和 Clang)
生成的代码显示了循环内 Vector 上的一些堆栈交换。 下面是没有任何堆栈交换的手写汇编中的相同函数:

    .arch   armv8-a
    .global     matVecMult81x90000_asm
    .text

.balign 64
.func
matVecMult81x90000_asm:
// init loop counter
    mov     w3, #90000 & 0xffff
    movk    w3, #90000>>16, lsl #16

// preserve registers
    stp     d8, d9, [sp, #-48]!
    stp     d10, d11, [sp, #1*16]
    stp     d12, d13, [sp, #2*16]

// load vectors
    ldp     q0, q1, [x2, #0*32]
    ldp     q2, q3, [x2, #1*32]
    ldp     q4, q5, [x2, #2*32]
    ldp     q6, q7, [x2, #3*32]
    ldp     q8, q9, [x2, #4*32]
    ldp     q10, q11, [x2, #5*32]
    ldp     q12, q13, [x2, #6*32]
    ldp     q16, q17, [x2, #7*32]
    ldp     q18, q19, [x2, #8*32]
    ldp     q20, q21, [x2, #9*32]
    ldr     s22, [x2, #10*32]

// loop
.balign 64
1:
    ldp     q24, q25, [x1, #0*32]
    ldp     q26, q27, [x1, #1*32]
    ldp     q28, q29, [x1, #2*32]
    ldp     q30, q31, [x1, #3*32]
    subs    w3, w3, #1

    fmul    v23.4s, v24.4s, v0.4s
    fmla    v23.4s, v25.4s, v1.4s
    fmla    v23.4s, v26.4s, v2.4s
    fmla    v23.4s, v27.4s, v3.4s
    fmla    v23.4s, v28.4s, v4.4s
    fmla    v23.4s, v29.4s, v5.4s
    fmla    v23.4s, v30.4s, v6.4s
    fmla    v23.4s, v31.4s, v7.4s

    ldp     q24, q25, [x1, #4*32]
    ldp     q26, q27, [x1, #5*32]
    ldp     q28, q29, [x1, #6*32]
    ldp     q30, q31, [x1, #7*32]

    fmla    v23.4s, v24.4s, v8.4s
    fmla    v23.4s, v25.4s, v9.4s
    fmla    v23.4s, v26.4s, v10.4s
    fmla    v23.4s, v27.4s, v11.4s
    fmla    v23.4s, v28.4s, v12.4s
    fmla    v23.4s, v29.4s, v13.4s
    fmla    v23.4s, v30.4s, v16.4s
    fmla    v23.4s, v31.4s, v17.4s

    ldp     q24, q25, [x1, #8*32]
    ldp     q26, q27, [x1, #9*32]
    ldr     s28, [x1, #10*32]

    fmla    v23.4s, v24.4s, v18.4s
    fmla    v23.4s, v25.4s, v19.4s
    fmla    v23.4s, v26.4s, v20.4s
    fmla    v23.4s, v27.4s, v21.4s
    fmla    v23.4s, v28.4s, v22.4s

    add     x1, x1, #81*4

    faddp   v23.4s, v23.4s, v23.4s
    faddp   v23.2s, v23.2s, v23.2s

    str     s23, [x0], #4
    b.ne    1b

.balign 8
//restore registers

    ldp     d10, d11, [sp, #1*16]
    ldp     d12, d13, [sp, #2*16]
    ldp     d8, d9, [sp], #48

// return
    ret

.endfunc
.end

RK3368测试结果:
Clang 内在函数:10.41 毫秒
装配:9.59ms

在这种情况下,编译器的表现并没有那么糟糕,但它们往往愚蠢得令人难以置信。强烈推荐学习汇编。

这里是 Jake 答案的优化。

使用 4 个累加器而不是一个累加器会有所帮助,因为 FMA 指令的延迟远高于吞吐量。根据 Cortex-A72 optimization guideFMLA 指令的延迟对于完整的东西是 7 个周期,或者当依赖于累加器时是 3 个周期(如果你想知道到底什么是 Q-form 和 D-form , Q 用于 16 字节向量,D 用于 8 字节向量)。吞吐量高得多,是1个周期,CPU每个周期可以运行一个FMA。

以下版本使用了 4 个独立的累加器而不是单个累加器,尽管我们在循环末尾需要 3 条额外的指令来对累加器求和,但应该会提高吞吐量。

我还使用了一些宏来帮助处理重复代码。未经测试。

void matVecMult81( float *pDst, const float *pMat, const float *pVec, size_t nRows = 90000 )
{
    // 30 vector registers in total; ARM64 has 32 of them, so we're good.
    float32x4_t vec0_3, vec4_7, vec8_11, vec12_15, vec16_19, vec20_23, vec24_27, vec28_31, vec32_35, vec36_39, vec40_43, vec44_47, vec48_51, vec52_55, vec56_59, vec60_63, vec64_67, vec68_71, vec72_75, vec76_79, vec80;
    float32x4_t mat0, mat1, mat2, mat3, mat4;
    float32x4_t res0, res1, res2, res3;

    vec80 = mat4 = vdupq_n_f32( 0.0f );
    // Load 16 numbers from pVec into 3 vector registers, incrementing the source pointer
#define LOAD_VEC_16( v0, v1, v2, v3 )      \
    v0 = vld1q_f32( pVec ); pVec += 4;     \
    v1 = vld1q_f32( pVec ); pVec += 4;     \
    v2 = vld1q_f32( pVec ); pVec += 4;     \
    v3 = vld1q_f32( pVec ); pVec += 4

    // Load the complete vector into registers using the above macro
    LOAD_VEC_16( vec0_3, vec4_7, vec8_11, vec12_15 );
    LOAD_VEC_16( vec16_19, vec20_23, vec24_27, vec28_31 );
    LOAD_VEC_16( vec32_35, vec36_39, vec40_43, vec44_47 );
    LOAD_VEC_16( vec48_51, vec52_55, vec56_59, vec60_63 );
    LOAD_VEC_16( vec64_67, vec68_71, vec72_75, vec76_79 );
    // Load the final scalar of the vector
    vec80 = vld1q_lane_f32( pVec, vec80, 0 );

#undef LOAD_VEC_16

    // Load 16 numbers from pMat into mat0 - mat3, incrementing the source pointer
#define LOAD_MATRIX_16()                         \
        mat0 = vld1q_f32( pMat ); pMat += 4;     \
        mat1 = vld1q_f32( pMat ); pMat += 4;     \
        mat2 = vld1q_f32( pMat ); pMat += 4;     \
        mat3 = vld1q_f32( pMat ); pMat += 4

    // Multiply 16 numbers in mat0 - mat3 by the specified pieces of the vector, and accumulate into res0 - res3
    // Multiple accumulators is critical for performance, 4 instructions produced by this macro don't have data dependencies between them.
#define HANDLE_BLOCK_16( v0, v1, v2, v3 )        \
        res0 = vfmaq_f32( res0, mat0, v0 );      \
        res1 = vfmaq_f32( res1, mat1, v1 );      \
        res2 = vfmaq_f32( res2, mat2, v2 );      \
        res3 = vfmaq_f32( res3, mat3, v3 )

    const float* const pMatEnd = pMat + nRows * 81;
    while( pMat < pMatEnd )
    {
        // Initial 16 elements only need multiplication.
        LOAD_MATRIX_16();
        res0 = vmulq_f32( mat0, vec0_3 );
        res1 = vmulq_f32( mat1, vec4_7 );
        res2 = vmulq_f32( mat2, vec8_11 );
        res3 = vmulq_f32( mat3, vec12_15 );

        // Handle the rest of the row using FMA instructions.
        LOAD_MATRIX_16();
        HANDLE_BLOCK_16( vec16_19, vec20_23, vec24_27, vec28_31 );

        LOAD_MATRIX_16();
        HANDLE_BLOCK_16( vec32_35, vec36_39, vec40_43, vec44_47 );

        LOAD_MATRIX_16();
        HANDLE_BLOCK_16( vec48_51, vec52_55, vec56_59, vec60_63 );

        // The final block of the row has 17 scalars instead of 16
        LOAD_MATRIX_16();
        mat4 = vld1q_lane_f32( pMat, mat4, 0 ); pMat++;

        HANDLE_BLOCK_16( vec64_67, vec68_71, vec72_75, vec76_79 );
        res0 = vfmaq_f32( res0, mat4, vec80 );

        // Vertically add 4 accumulators into res0
        res1 = vaddq_f32( res1, res2 );
        res0 = vaddq_f32( res3, res0 );
        res0 = vaddq_f32( res1, res0 );

        // Store the horizontal sum of the accumulator
        *pDst = vaddvq_f32( res0 );
        pDst++;
    }

#undef LOAD_MATRIX_16
#undef HANDLE_BLOCK_16
}

使用 GCC 10.1 从该源生成的程序集 looks more or less OK