C++ 中一个复杂问题的手动矢量化(vectorization/SSE)

Manual vectorization/SSE for a complex problem in C++

我想加速我的算法,它是一个目标函数(objective function)f(x),问题维度为 5000。我已经在代码中做了很多改进,但计算时间仍然达不到我的预期。

大部分数据都以 (float*)_mm_malloc(N_h*sizeof(float),16) 的方式动态分配。在目标函数中包含"长" for 循环的部分,我已经成功使用了 _mm_mul_ps、_mm_rcp_ps、_mm_store_ps …… 以及 __m128 变量等。我还引入了线程(_beginthreadex)来加速最慢的代码。但有一部分代码不容易矢量化……我附上了问题最严重(计算最慢)的那部分代码,但我仍然无法改进它(提醒一下,这只是一个更大计算中的一部分,但从这里已经能看出我的问题)。我期望得到矢量化的计算,但实际上每一行代码都被编译成了简单的标量运算(汇编代码中有大量 MOVSS、MULSS、SUBSS 等指令)。有人能提示一下问题出在哪里吗?

我在 Windows 机器上使用 MinGW GCC-8.2.0-3 编译器,带有 -O3 -march=native -ffast-math 标志。

#include <immintrin.h>
#include "math.h"
#define N_h 5000

// Input arrays read by main(). Each is expected to point at N_h floats
// obtained from (float*)_mm_malloc(N_h*sizeof(float),16); that allocation
// is not shown in this snippet.
float *x_vec, *data0, *data1, *data2, *data3;

int main() 
{
    float* q_vec = (float*)_mm_malloc(8*sizeof(float),16);
    float* xx_vec = (float*)_mm_malloc(8*sizeof(float),16);
    float* cP_vec = (float*)_mm_malloc(8*sizeof(float),16);
    float* xPtr = x_vec;
    float* f32Ptr;
    float c0;
    int n = N_h;
    int sum = 0;

    while(n > 0)
    {
        int k=1;
        n-=8;

        cP_vec[0] = 1;
        cP_vec[1] = 1;
        cP_vec[2] = 1;
        cP_vec[3] = 1;
        cP_vec[4] = 1;
        cP_vec[5] = 1;
        cP_vec[6] = 1;
        cP_vec[7] = 1;
        //preload of x data shall be done with vector preload, currently it is row-by-row **MOVS**
        xx_vec[0] = *xPtr++;
        xx_vec[1] = *xPtr++;
        xx_vec[2] = *xPtr++;
        xx_vec[3] = *xPtr++;
        xx_vec[4] = *xPtr++;
        xx_vec[5] = *xPtr++;
        xx_vec[6] = *xPtr++;
        xx_vec[7] = *xPtr++;

        c0 = data0[k];
        //I am expecting vector subtraction here, but each of the row generates almost same assembly code
        q_vec[0] = xx_vec[0] - c0;
        q_vec[1] = xx_vec[1] - c0;
        q_vec[2] = xx_vec[2] - c0;
        q_vec[3] = xx_vec[3] - c0;
        q_vec[4] = xx_vec[4] - c0;
        q_vec[5] = xx_vec[5] - c0;
        q_vec[6] = xx_vec[6] - c0;
        q_vec[7] = xx_vec[7] - c0;
        //if I create more internal variable for all of the multiplication, does it help?
        cP_vec[0] = cP_vec[0] * data1[k] * exp(-pow(q_vec[0], 2.0f) * data2[k]);
        cP_vec[1] = cP_vec[1] * data1[k] * exp(-pow(q_vec[1], 2.0f) * data2[k]);
        cP_vec[2] = cP_vec[2] * data1[k] * exp(-pow(q_vec[2], 2.0f) * data2[k]);
        cP_vec[3] = cP_vec[3] * data1[k] * exp(-pow(q_vec[3], 2.0f) * data2[k]);
        cP_vec[4] = cP_vec[4] * data1[k] * exp(-pow(q_vec[4], 2.0f) * data2[k]);
        cP_vec[5] = cP_vec[5] * data1[k] * exp(-pow(q_vec[5], 2.0f) * data2[k]);
        cP_vec[6] = cP_vec[6] * data1[k] * exp(-pow(q_vec[6], 2.0f) * data2[k]);
        cP_vec[7] = cP_vec[7] * data1[k] * exp(-pow(q_vec[7], 2.0f) * data2[k]);
        k++;
        f32Ptr = &data3[k];
        for (int j =1; j <= 5; j++) //the index of this for is defined by a variable in my application, so it is not a constant
        {
            c0 = data0[k];
            //here the subtraction and multiplication is not vectoritzed
            q_vec[0] = (xx_vec[0] - c0) * (*f32Ptr);
            q_vec[1] = (xx_vec[1] - c0) * (*f32Ptr);
            q_vec[2] = (xx_vec[2] - c0) * (*f32Ptr);
            q_vec[3] = (xx_vec[3] - c0) * (*f32Ptr);
            q_vec[4] = (xx_vec[4] - c0) * (*f32Ptr);
            q_vec[5] = (xx_vec[5] - c0) * (*f32Ptr);
            q_vec[6] = (xx_vec[6] - c0) * (*f32Ptr);
            q_vec[7] = (xx_vec[7] - c0) * (*f32Ptr);

            q_vec[0] = (0.5f - 0.5f*erf( q_vec[0] ) );
            q_vec[1] = (0.5f - 0.5f*erf( q_vec[1] ) );
            q_vec[2] = (0.5f - 0.5f*erf( q_vec[2] ) );
            q_vec[3] = (0.5f - 0.5f*erf( q_vec[3] ) );
            q_vec[4] = (0.5f - 0.5f*erf( q_vec[4] ) );
            q_vec[5] = (0.5f - 0.5f*erf( q_vec[5] ) );
            q_vec[6] = (0.5f - 0.5f*erf( q_vec[6] ) );
            q_vec[7] = (0.5f - 0.5f*erf( q_vec[7] ) );
            //here the multiplication is not vectorized...
            cP_vec[0] = cP_vec[0] * q_vec[0];
            cP_vec[1] = cP_vec[1] * q_vec[1];
            cP_vec[2] = cP_vec[2] * q_vec[2];
            cP_vec[3] = cP_vec[3] * q_vec[3];
            cP_vec[4] = cP_vec[4] * q_vec[4];
            cP_vec[5] = cP_vec[5] * q_vec[5];
            cP_vec[6] = cP_vec[6] * q_vec[6];
            cP_vec[7] = cP_vec[7] * q_vec[7];
            f32Ptr++;
            k++;
        }
        sum += cP_vec[0];
        sum += cP_vec[1];
        sum += cP_vec[2];
        sum += cP_vec[3];
        sum += cP_vec[4];
        sum += cP_vec[5];
        sum += cP_vec[6];
        sum += cP_vec[7];
    }
    return 0;
}

在Godbolt上可以看到汇编代码: https://godbolt.org/z/wbkNAk


更新:

我已经用 SSE 实现了部分计算,但加速比只有约 1.10–1.15 倍,远低于预期……我在 main() 中是不是做错了什么?

#include <immintrin.h>
#include "math.h"
#define N_h 5000

#define EXP_TABLE_SIZE 10
static const __m128 M128_1 = {1.0, 1.0, 1.0, 1.0};

float* x_vec;   // allocated as: (float*)_mm_malloc(N_h*sizeof(float),16);
float* data0; //allocated as: (float*)_mm_malloc(N_h*sizeof(float),16);
float* data1; //allocated as: (float*)_mm_malloc(N_h*sizeof(float),16);
float* data2; //allocated as: (float*)_mm_malloc(N_h*sizeof(float),16);
float* data3; //allocated as: (float*)_mm_malloc(N_h*sizeof(float),16);

// Constants and lookup table used by exp_ps().
//
// Every scalar constant is replicated into an 8-element array so it can be
// loaded straight into a vector register (exp_ps reads the first four lanes;
// presumably the 8 copies leave room for an AVX variant -- TODO confirm).
// tbl[i] holds the 23 mantissa bits of 2^(i / n), for i in [0, n), n = 2^s.
//
// The vector-loaded arrays are forced to 16-byte alignment. As plain
// float/unsigned members they were only guaranteed 4-byte alignment, which
// is not enough for aligned SSE loads taken directly from them.
// (The previous "typedef struct ExpVar {...};" had no typedef name, so the
// typedef was ignored; it has been dropped.)
struct ExpVar {
    enum {
        s = EXP_TABLE_SIZE,        // log2 of the table size
        n = 1 << s,                // number of table entries
        f88 = 0x42b00000 /* 88.0 */
    };
    float minX[8] __attribute__((aligned(16)));                // -88: lower clamp
    float maxX[8] __attribute__((aligned(16)));                // +88: upper clamp
    float a[8] __attribute__((aligned(16)));                   // n / ln2
    float b[8] __attribute__((aligned(16)));                   // ln2 / n
    float f1[8] __attribute__((aligned(16)));                  // 1.0f
    unsigned int i127s[8] __attribute__((aligned(16)));        // exponent bias 127, pre-shifted by s
    unsigned int mask_s[8] __attribute__((aligned(16)));       // low-s-bit mask (table index)
    unsigned int i7fffffff[8] __attribute__((aligned(16)));    // clears the float sign bit
    unsigned int tbl[n];
    // Float <-> bit-pattern reinterpretation (union punning, supported by GCC).
    union fi {
        float f;
        unsigned int i;
    };
    ExpVar()
    {
        float log_2 = ::logf(2.0f);
        for (int i = 0; i < 8; i++) {
            maxX[i] = 88;
            minX[i] = -88;
            a[i] = n / log_2;
            b[i] = log_2 / n;
            f1[i] = 1.0f;
            i127s[i] = 127 << s;
            i7fffffff[i] = 0x7fffffff;
            mask_s[i] = mask(s);
        }

        // Keep only the mantissa of 2^(i/n); exp_ps supplies the exponent.
        for (int i = 0; i < n; i++) {
            float y = pow(2.0f, (float)i / n);
            fi bits;
            bits.f = y;
            tbl[i] = bits.i & mask(23);
        }
    }
    // Returns a value with the low x bits set.
    inline unsigned int mask(int x)
    {
        return (1U << x) - 1;
    }
};

inline __m128 exp_ps(__m128 x, ExpVar* expVar)
{
    __m128i limit = _mm_castps_si128(_mm_and_ps(x, *(__m128*)expVar->i7fffffff));
    int over = _mm_movemask_epi8(_mm_cmpgt_epi32(limit, *(__m128i*)expVar->maxX));
    if (over) {
        x = _mm_min_ps(x, _mm_load_ps(expVar->maxX));
        x = _mm_max_ps(x, _mm_load_ps(expVar->minX));
    }

    __m128i r = _mm_cvtps_epi32(_mm_mul_ps(x, *(__m128*)(expVar->a)));
    __m128 t = _mm_sub_ps(x, _mm_mul_ps(_mm_cvtepi32_ps(r), *(__m128*)(expVar->b)));
    t = _mm_add_ps(t, *(__m128*)(expVar->f1));

    __m128i v4 = _mm_and_si128(r, *(__m128i*)(expVar->mask_s));
    __m128i u4 = _mm_add_epi32(r, *(__m128i*)(expVar->i127s));
    u4 = _mm_srli_epi32(u4, expVar->s);
    u4 = _mm_slli_epi32(u4, 23);

    unsigned int v0, v1, v2, v3;
    v0 = _mm_cvtsi128_si32(v4);
    v1 = _mm_extract_epi16(v4, 2);
    v2 = _mm_extract_epi16(v4, 4);
    v3 = _mm_extract_epi16(v4, 6);
    __m128 t0, t1, t2, t3;

    t0 = _mm_castsi128_ps(_mm_set1_epi32(expVar->tbl[v0]));
    t1 = _mm_castsi128_ps(_mm_set1_epi32(expVar->tbl[v1]));
    t2 = _mm_castsi128_ps(_mm_set1_epi32(expVar->tbl[v2]));
    t3 = _mm_castsi128_ps(_mm_set1_epi32(expVar->tbl[v3]));

    t1 = _mm_movelh_ps(t1, t3);
    t1 = _mm_castsi128_ps(_mm_slli_epi64(_mm_castps_si128(t1), 32));
    t0 = _mm_movelh_ps(t0, t2);
    t0 = _mm_castsi128_ps(_mm_srli_epi64(_mm_castps_si128(t0), 32));
    t0 = _mm_or_ps(t0, t1);

    t0 = _mm_or_ps(t0, _mm_castsi128_ps(u4));

    t = _mm_mul_ps(t, t0);

    return t;
}

int main() 
{
    float* q_vec = (float*)_mm_malloc(8*sizeof(float),16);
    float* xx_vec = (float*)_mm_malloc(8*sizeof(float),16);
    float* cP_vec = (float*)_mm_malloc(8*sizeof(float),16);
    float* xPtr = x_vec;
    float* f32Ptr;
    __m128 c0,c1;
    __m128* m128Var1;
    __m128* m128Var2;
    float* f32Ptr1;
    float* f32Ptr2;
    int n = N_h;
    int sum = 0;
    ExpVar expVar;

    while(n > 0)
    {
        int k=1;
        n-=8;

        //cP_vec[0] = 1;
        f32Ptr1 = cP_vec;
        _mm_store_ps(f32Ptr1,M128_1);
        f32Ptr1+=4;
        _mm_store_ps(f32Ptr1,M128_1);
        //preload x data
        //xx_vec[0] = *xPtr++;
        f32Ptr1 = xx_vec;
        m128Var1 = (__m128*)xPtr;
        _mm_store_ps(f32Ptr1,*m128Var1);
        m128Var1++;
        xPtr+=4;
        f32Ptr1+=4;
        m128Var1 = (__m128*)xPtr;
        _mm_store_ps(f32Ptr1,*m128Var1);
        xPtr+=4;

        c0 = _mm_set1_ps(data0[k]);
        m128Var1 = (__m128*)xx_vec;
        f32Ptr1 = q_vec;
        _mm_store_ps(f32Ptr1, _mm_sub_ps(*m128Var1, c0) );
        m128Var1++;
        f32Ptr1+=4;
        _mm_store_ps(f32Ptr1, _mm_sub_ps(*m128Var1, c0) );
        //calc -pow(q_vec[0], 2.0f)
        f32Ptr1 = q_vec;
        m128Var1 = (__m128*)q_vec;
        _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var1, *m128Var1) );
        m128Var1++;
        f32Ptr1+=4;
        _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var1, *m128Var1) );
        m128Var1 = (__m128*)q_vec;
        *m128Var1 = _mm_xor_ps(*m128Var1, _mm_set1_ps(-0.0));
        m128Var1++;
        *m128Var1 = _mm_xor_ps(*m128Var1, _mm_set1_ps(-0.0));
        //-pow(q_vec[0], 2.0f) * data2[k]

        c0 = _mm_set1_ps(data2[k]);
        f32Ptr1 = q_vec;
        m128Var1 = (__m128*)q_vec;
        _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var1, c0) );
        m128Var1++;
        f32Ptr1+=4;
        _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var1, c0) );
        m128Var1 = (__m128*)q_vec;
        //calc exp(x)
        *m128Var1 = exp_ps(*m128Var1,&expVar);
        m128Var1++;
        *m128Var1 = exp_ps(*m128Var1,&expVar);
        //data1[k] * exp(x)
        c0 = _mm_set1_ps(data1[k]);
        f32Ptr1 = q_vec;
        m128Var1 = (__m128*)q_vec;
        _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var1, c0) );
        m128Var1++;
        f32Ptr1+=4;
        _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var1, c0) );
        //cP_vec[0] * data1[k] * exp(x)
        f32Ptr1 = cP_vec;
        m128Var1 = (__m128*)cP_vec;
        m128Var2 = (__m128*)q_vec;
        _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var1, *m128Var2) );
        m128Var1++;
        m128Var2++;
        f32Ptr1+=4;
        _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var1, *m128Var2) );

        k++;
        for (int j =1; j <= 5; j++)
        {
            c0 = _mm_set1_ps(data0[k]);
            c1 = _mm_set1_ps(data3[k]);
            m128Var1 = (__m128*)xx_vec;
            m128Var2 = (__m128*)q_vec;
            f32Ptr1 = q_vec;
            _mm_store_ps(f32Ptr1, _mm_sub_ps(*m128Var1, c0) );
            _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var2, c1) );
            m128Var1++;
            m128Var2++;
            f32Ptr1+=4;
            _mm_store_ps(f32Ptr1, _mm_sub_ps(*m128Var1, c0) );
            _mm_store_ps(f32Ptr1, _mm_mul_ps(*m128Var2, c1) );

            q_vec[0] = (0.5f - 0.5f*erf( q_vec[0] ) );
            q_vec[1] = (0.5f - 0.5f*erf( q_vec[1] ) );
            q_vec[2] = (0.5f - 0.5f*erf( q_vec[2] ) );
            q_vec[3] = (0.5f - 0.5f*erf( q_vec[3] ) );
            q_vec[4] = (0.5f - 0.5f*erf( q_vec[4] ) );
            q_vec[5] = (0.5f - 0.5f*erf( q_vec[5] ) );
            q_vec[6] = (0.5f - 0.5f*erf( q_vec[6] ) );
            q_vec[7] = (0.5f - 0.5f*erf( q_vec[7] ) );

            cP_vec[0] = cP_vec[0] * q_vec[0];
            cP_vec[1] = cP_vec[1] * q_vec[1];
            cP_vec[2] = cP_vec[2] * q_vec[2];
            cP_vec[3] = cP_vec[3] * q_vec[3];
            cP_vec[4] = cP_vec[4] * q_vec[4];
            cP_vec[5] = cP_vec[5] * q_vec[5];
            cP_vec[6] = cP_vec[6] * q_vec[6];
            cP_vec[7] = cP_vec[7] * q_vec[7];
            k++;
        }
        sum += cP_vec[0];
        sum += cP_vec[1];
        sum += cP_vec[2];
        sum += cP_vec[3];
        sum += cP_vec[4];
        sum += cP_vec[5];
        sum += cP_vec[6];
        sum += cP_vec[7];
    }
    return 0;
}

https://godbolt.org/z/N7K6j0

这段代码很奇怪:所有中间结果都被显式地存储到内存中,而不是放在普通的局部变量里。我把它改得不那么奇怪,并添加了矢量化的 erf——它才是主要的计算开销。由于我不知道这段代码具体要做什么,除了性能之外(性能确实提升了),我无法真正测试它的结果。

while (n > 0)
{
    n -= 8;
    int k = 1;

    // The eight x values of this block, kept in two registers throughout.
    const __m128 xa = _mm_load_ps(xPtr);
    const __m128 xb = _mm_load_ps(xPtr + 4);
    xPtr += 8;

    // Leading factor: data1[k] * exp(-(x - data0[k])^2 * data2[k]).
    const __m128 centre0 = _mm_set1_ps(data0[k]);
    const __m128 da = _mm_sub_ps(xa, centre0);
    const __m128 db = _mm_sub_ps(xb, centre0);
    const __m128 negScale = _mm_xor_ps(_mm_set1_ps(data2[k]), _mm_set1_ps(-0.0));
    const __m128 expA = fast_exp_sse(_mm_mul_ps(_mm_mul_ps(da, da), negScale));
    const __m128 expB = fast_exp_sse(_mm_mul_ps(_mm_mul_ps(db, db), negScale));
    const __m128 weight = _mm_set1_ps(data1[k]);
    __m128 prodA = _mm_mul_ps(expA, weight);
    __m128 prodB = _mm_mul_ps(expB, weight);
    k++;

    // Remaining factors: prod (0.5 - 0.5 * erf((x - data0[k]) * data3[k])).
    const __m128 half = _mm_set1_ps(0.5);
    for (int j = 1; j <= 5; j++, k++)
    {
        const __m128 centre = _mm_set1_ps(data0[k]);
        const __m128 slope = _mm_set1_ps(data3[k]);
        const __m128 za = _mm_mul_ps(_mm_sub_ps(xa, centre), slope);
        const __m128 zb = _mm_mul_ps(_mm_sub_ps(xb, centre), slope);
        const __m128 fa = _mm_sub_ps(half, _mm_mul_ps(half, erf_sse(za)));
        const __m128 fb = _mm_sub_ps(half, _mm_mul_ps(half, erf_sse(zb)));
        prodA = _mm_mul_ps(prodA, fa);
        prodB = _mm_mul_ps(prodB, fb);
    }

    // Horizontal sum of all eight lanes (_mm_hadd_ps is SSE3).
    __m128 total = _mm_add_ps(prodA, prodB);
    total = _mm_hadd_ps(total, total);
    total = _mm_hadd_ps(total, total);
    sum += _mm_cvtss_f32(total);
}

对于 erf 我使用了:

// Vectorised erf(x) for four floats.
//
// Uses the Abramowitz & Stegun 7.1.28 approximation
//   erf(x) ~= 1 - 1 / (1 + a1 x + a2 x^2 + ... + a6 x^6)^16,  x >= 0,
// (absolute error <= 3e-7), with the polynomial evaluated by Horner's scheme.
// The formula only holds for x >= 0. For negative x the old version returned
// values outside [-1, 1] (about -1.26 at x = -2). It is therefore evaluated on
// |x|, and the sign of x is copied back onto the result, since erf is odd.
// Large |x| saturate to +/-1 because p^16 overflows to +inf and 1/inf == 0.
__m128 erf_sse(__m128 x)
{
    const __m128 sign_mask = _mm_set1_ps(-0.0f);
    const __m128 sign = _mm_and_ps(x, sign_mask);   // sign bit of each lane
    const __m128 ax = _mm_andnot_ps(sign_mask, x);  // |x|

    const __m128 a1 = _mm_set1_ps(0.0705230784f);
    const __m128 a2 = _mm_set1_ps(0.0422820123f);
    const __m128 a3 = _mm_set1_ps(0.0092705272f);
    const __m128 a4 = _mm_set1_ps(0.0001520143f);
    const __m128 a5 = _mm_set1_ps(0.0002765672f);
    const __m128 a6 = _mm_set1_ps(0.0000430638f);
    const __m128 one = _mm_set1_ps(1.0f);

    __m128 p = _mm_add_ps(one,
        _mm_mul_ps(ax, _mm_add_ps(a1,
            _mm_mul_ps(ax, _mm_add_ps(a2,
                _mm_mul_ps(ax, _mm_add_ps(a3,
                    _mm_mul_ps(ax, _mm_add_ps(a4,
                        _mm_mul_ps(ax, _mm_add_ps(a5,
                            _mm_mul_ps(ax, a6))))))))))));

    // p^16 by four successive squarings.
    p = _mm_mul_ps(p, p);
    p = _mm_mul_ps(p, p);
    p = _mm_mul_ps(p, p);
    p = _mm_mul_ps(p, p);

    const __m128 r = _mm_sub_ps(one, _mm_div_ps(one, p));
    return _mm_or_ps(r, sign);
}

我对这个实现不太有把握:它只是把维基百科上的一个公式(Abramowitz–Stegun 7.1.28)改写成 SSE intrinsics,并用 Horner 法计算多项式。注意该公式只适用于 x ≥ 0,负参数需要利用 erf(−x) = −erf(x) 处理。可能还有更好的方法。

至于 fast_exp_sse,用的是常见的组合:提取指数位再加上多项式近似。大量使用查找表很容易抵消 SIMD 带来的收益。