水平 运行 差异和使用 SIMD/SSE 的条件更新?

Horizontal running diff and conditional update using SIMD/SSE?

我想向量化以下操作:

V[i+1] = max(V[i] - c, V[i+1]) for i=1 to n-1 (V[0] = 0)

对应的朴素伪代码为:

for (i=0; i < n; i++) {
  if (V[i]-c > V[i+1]) V[i+1] = V[i]-c
}

哪些 SIMD 指令可能有用?

这可以通过 SIMD 完成。解决方法类似于the solution for the prefix sum with SIMD.

在 SIMD 寄存器中,迭代次数为 O(Log2(simd_width))。每次迭代需要:一次移位、一次减法和一次最大值。例如,对于 SSE,它需要 Log2(4) = 2 次迭代。您可以像这样在四个元素上应用您的函数:

__m128i foo_SSE(__m128i x, int c) {
    __m128i t, c1, c2;
    c1 = _mm_set1_epi32(c);
    c2 = _mm_set1_epi32(2*c);

    t = _mm_slli_si128(x, 4);
    t = _mm_sub_epi32(t, c1);
    x = _mm_max_epi32(x, t);

    t = _mm_slli_si128(x, 8);
    t = _mm_sub_epi32(t, c2);
    x = _mm_max_epi32(x, t);
    return x;
}

获得 SIMD 寄存器的结果后,您需要将 "carry" 应用于下一个寄存器。例如,假设您有一个包含八个元素的数组 a。你像这样加载 SSE 寄存器 x1x2

__m128i x1 = _mm_loadu_si128((__m128i*)&a[0]);
__m128i x2 = _mm_loadu_si128((__m128i*)&a[4]);

然后将您的函数应用于所有八个元素,您会做

__m128i t, s;
s = _mm_setr_epi32(c, 2*c, 3*c, 4*c);

x1 = foo_SSE(x1,c);
x2 = foo_SSE(x2,c);
t = _mm_shuffle_epi32(x1, 0xff);
t = _mm_sub_epi32(t,s);
x2 = _mm_max_epi32(x2,t);

请注意,c1c2s都是循环中的常量,因此只需计算一次。

通常,您可以将函数应用于无符号整数数组 a,就像使用 SSE 一样(n 是 4 的倍数):

void fill_SSE(int *a, int n, int c) {
    __m128i offset = _mm_setzero_si128();
    __m128i s = _mm_setr_epi32(c, 2*c, 3*c, 4*c);
    for(int i=0; i<n/4; i++) {
        __m128i x = _mm_loadu_si128((__m128i*)&a[4*i]);
        __m128i out = foo_SSE(x, c);
        out = _mm_max_epi32(out,offset);
        _mm_storeu_si128((__m128i*)&a[4*i], out);
        offset = _mm_shuffle_epi32(out, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
}

我继续分析这段 SSE 代码。 It's about 2.5 times faster than the serial version.

除了像 log2(simd_width) 那样进行之外,此方法的另一个主要优点是它打破了依赖链,因此多个 SIMD 操作可以同时进行(使用多个端口),而不是等待上一个结果。串行代码受延迟限制。

当前代码适用于无符号整数,但您可以将其推广到有符号整数和浮点数。

这是我用来测试的通用代码。在实现 SSE 版本之前,我创建了一堆抽象的 SIMD 函数来模拟 SIMD 硬件。

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <x86intrin.h>
#include <omp.h>

__m128i foo_SSE(__m128i x, int c) {
    __m128i t, c1, c2;
    c1 = _mm_set1_epi32(c);
    c2 = _mm_set1_epi32(2*c);

    t = _mm_slli_si128(x, 4);
    t = _mm_sub_epi32(t, c1);
    x = _mm_max_epi32(x, t);

    t = _mm_slli_si128(x, 8);
    t = _mm_sub_epi32(t, c2);
    x = _mm_max_epi32(x, t);
    return x;
}

void foo(int *a, int n, int c) {
    for(int i=0; i<n-1; i++) {
        if(a[i]-c > a[i+1]) a[i+1] = a[i]-c;
    }
}

void broad(int *a, int n, int k) {
    for(int i=0; i<n; i++) a[i] = k;
}

void shiftr(int *a, int *b, int n, int m) {
    int i;
    for(i=0; i<m; i++) b[i] = a[i];
    for(; i<n; i++) b[i] = a[i-m];
}

/*
void shiftr(int *a, int *b, int n, int m) {
    int i;
    for(i=0; i<m; i++) b[i] = 0;
    for(; i<n; i++) b[i] = a[i-m];
}
*/

void sub(int *a, int n, int c) {
    for(int i=0; i<n; i++) a[i] -= c;
}


void max(int *a, int *b, int n) {
    for(int i=0; i<n; i++) if(b[i]>a[i]) a[i] = b[i];
}

void step(int *a, int n, int c) {
    for(int i=0; i<n; i++) {
        a[i] -= (i+1)*c;
    }
}

void foo2(int *a, int n, int c) {
    int b[n];
    for(int m=1; m<n; m*=2) {
        shiftr(a,b,n,m);
        sub(b, n, m*c);
        max(a,b,n);
        //printf("n %d, m %d; ", n,m ); for(int i=0; i<n; i++) printf("%2d ", b[i]); puts("");
    }
}

void fill(int *a, int n, int w, int c) {
    int b[w], offset[w];
    broad(offset, w, -1000);
    for(int i=0; i<n/w; i++) {
        for(int m=1; m<w; m*=2) {
            shiftr(&a[w*i],b,w,m);
            sub(b, w, m*c);
            max(&a[w*i],b,w);
        }
        max(&a[w*i],offset,w);
        broad(offset,w,a[w*i+w-1]);
        step(offset, w, c);
    }
}


void fill_SSE(int *a, int n, int c) {
    __m128i offset = _mm_setzero_si128();
    __m128i s = _mm_setr_epi32(c, 2*c, 3*c, 4*c);
    for(int i=0; i<n/4; i++) {
        __m128i x = _mm_loadu_si128((__m128i*)&a[4*i]);
        __m128i out = foo_SSE(x, c);
        out = _mm_max_epi32(out,offset);
        _mm_storeu_si128((__m128i*)&a[4*i], out);
        offset = _mm_shuffle_epi32(out, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
}

void fill_SSEv2(int *a, int n, int c) {
    __m128i offset = _mm_setzero_si128();
    __m128i s = _mm_setr_epi32(1*c, 2*c, 3*c, 4*c);
    __m128i c1 = _mm_set1_epi32(1*c);
    __m128i c2 = _mm_set1_epi32(2*c);
    for(int i=0; i<n/4; i++) {
        __m128i x1 = _mm_loadu_si128((__m128i*)&a[4*i]);
        __m128i t1;

        t1 = _mm_slli_si128(x1, 4);
        t1 = _mm_sub_epi32 (t1, c1);
        x1 = _mm_max_epi32 (x1, t1);

        t1 = _mm_slli_si128(x1, 8);
        t1 = _mm_sub_epi32 (t1, c2);
        x1 = _mm_max_epi32 (x1, t1);

        x1 = _mm_max_epi32(x1,offset);
        _mm_storeu_si128((__m128i*)&a[4*i], x1);
        offset = _mm_shuffle_epi32(x1, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
}

void fill_SSEv3(int *a, int n, int c) {
    __m128i offset = _mm_setzero_si128();
    __m128i s = _mm_setr_epi32(1*c, 2*c, 3*c, 4*c);
    __m128i c1 = _mm_set1_epi32(1*c);
    __m128i c2 = _mm_set1_epi32(2*c);
    for(int i=0; i<n/8; i++) {
        __m128i x1 = _mm_loadu_si128((__m128i*)&a[8*i]);
        __m128i x2 = _mm_loadu_si128((__m128i*)&a[8*i+4]);
        __m128i t1, t2;

        t1 = _mm_slli_si128(x1, 4);
        t1 = _mm_sub_epi32 (t1, c1);
        x1 = _mm_max_epi32 (x1, t1);

        t1 = _mm_slli_si128(x1, 8);
        t1 = _mm_sub_epi32 (t1, c2);
        x1 = _mm_max_epi32 (x1, t1);

        t2 = _mm_slli_si128(x2, 4);
        t2 = _mm_sub_epi32 (t2, c1);
        x2 = _mm_max_epi32 (x2, t2);

        t2 = _mm_slli_si128(x2, 8);
        t2 = _mm_sub_epi32 (t2, c2);
        x2 = _mm_max_epi32 (x2, t2);

        x1 = _mm_max_epi32(x1,offset);
        _mm_storeu_si128((__m128i*)&a[8*i], x1);
        offset = _mm_shuffle_epi32(x1, 0xff);
        offset = _mm_sub_epi32(offset,s);

        x2 = _mm_max_epi32(x2,offset);
        _mm_storeu_si128((__m128i*)&a[8*i+4], x2);
        offset = _mm_shuffle_epi32(x2, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
}

int main(void) {
    int n = 8, a[n], a1[n], a2[n];
    for(int i=0; i<n; i++) a[i] = i;

    /*
    a[0] = 1, a[1] = 0;
    a[2] = 2, a[3] = 0;
    a[4] = 3, a[5] = 13;
    a[6] = 4, a[7] = 0;
    */


    a[0] = 5, a[1] = 6;
    a[2] = 7, a[3] = 8;
    a[4] = 1, a[5] = 2;
    a[6] = 3, a[7] = 4;

    for(int i=0; i<n; i++) printf("%2d ", a[i]); puts("");
    for(int i=0; i<n; i++) a1[i] = a[i], a2[i] = a[i];

    int c = 1;
    foo(a1,n,c);
    foo2(a2,n,c);
    for(int i=0; i<n; i++) printf("%2d ", a1[i]); puts("");
    for(int i=0; i<n; i++) printf("%2d ", a2[i]); puts("");


    __m128i x1 = _mm_loadu_si128((__m128i*)&a[0]);
    __m128i x2 = _mm_loadu_si128((__m128i*)&a[4]);
    __m128i t, s;
    s = _mm_setr_epi32(c, 2*c, 3*c, 4*c);

    x1 = foo_SSE(x1,c);
    x2 = foo_SSE(x2,c);
    t = _mm_shuffle_epi32(x1, 0xff);
    t = _mm_sub_epi32(t,s);
    x2 = _mm_max_epi32(x2,t);

    int a3[8];
    _mm_storeu_si128((__m128i*)&a3[0], x1);
    _mm_storeu_si128((__m128i*)&a3[4], x2);
    for(int i=0; i<8; i++) printf("%2d ", a3[i]); puts("");

    int w = 8;
    n = w*1000;
    int f1[n], f2[n];
    for(int i=0; i<n; i++) f1[i] = rand()%1000;

    for(int i=0; i<n; i++) f2[i] = f1[i];
    //for(int i=0; i<n; i++) printf("%2d ", f1[i]); puts("");
    foo(f1, n, c);
    //fill(f2, n, 8, c);
    fill_SSEv3(f2, n, c);
    printf("%d\n", memcmp(f1,f2,sizeof(int)*n));
    for(int i=0; i<n; i++) {
        //    if(f1[i] != f2[i]) printf("%d\n", i);
    }
    //for(int i=0; i<n; i++) printf("%2d ", f1[i]); puts("");
    //for(int i=0; i<n; i++) printf("%2d ", f2[i]); puts("");

    int r = 200000;
    double dtime;
    dtime = -omp_get_wtime();
    for(int i=0; i<r; i++) fill_SSEv2(f2, n, c);
    //for(int i=0; i<r; i++) foo(f1, n, c);
    dtime += omp_get_wtime();
    printf("time %f\n", dtime);

    dtime = -omp_get_wtime();
    for(int i=0; i<r; i++) fill_SSEv3(f2, n, c);
    //for(int i=0; i<r; i++) foo(f1, n, c);
    dtime += omp_get_wtime();
    printf("time %f\n", dtime);

    dtime = -omp_get_wtime();
    for(int i=0; i<r; i++) foo(f1, n, c);
    //for(int i=0; i<r; i++) fill_SSEv2(f2, n, c);
    dtime += omp_get_wtime();
    printf("time %f\n", dtime);
}

根据 Paul R 的评论,我能够修复我的函数以处理带符号的整数。但是,它需要 c>=0。我确信它可以修复为 c<0.

工作
void fill_SSEv2(int *a, int n, int c) {
    __m128i offset = _mm_set1_epi32(0xf0000000);
    __m128i s = _mm_setr_epi32(1*c, 2*c, 3*c, 4*c);
    __m128i c1 = _mm_set1_epi32(1*c);
    __m128i c2 = _mm_set1_epi32(2*c);
    for(int i=0; i<n/4; i++) {
        __m128i x1 = _mm_loadu_si128((__m128i*)&a[4*i]);
        __m128i t1;

        t1 = _mm_shuffle_epi32(x1, 0x90);
        t1 = _mm_sub_epi32 (t1, c1);
        x1 = _mm_max_epi32 (x1, t1);

        t1 = _mm_shuffle_epi32(x1, 0x44);
        t1 = _mm_sub_epi32 (t1, c2);
        x1 = _mm_max_epi32 (x1, t1);

        x1 = _mm_max_epi32(x1,offset);
        _mm_storeu_si128((__m128i*)&a[4*i], x1);
        offset = _mm_shuffle_epi32(x1, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
}

此方法现在应该很容易扩展到浮点数。