在 AVX2 中重现 _mm256_sllv_epi16 和 _mm256_sllv_epi8

Reproduce _mm256_sllv_epi16 and _mm256_sllv_epi8 in AVX2

我很惊讶 _mm256_sllv_epi16/8(__m256i v1, __m256i v2)_mm256_srlv_epi16/8(__m256i v1, __m256i v2) 不在 Intel Intrinsics Guide 中,而且我没有找到任何解决方案来仅使用 AVX2 重新创建 AVX512 内在函数。

此函数将所有 16/8 位 packed int 左移 v2 中相应数据元素的计数值。

epi16 示例:

__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
v1 = _mm256_sllv_epi16(v1, v2);

则v1等于->(1111111111111111, 1111111111111110, 1111111111111100, 1111111111111000, ..............., 1000000000000000);

奇怪的是他们错过了这一点,尽管似乎许多 AVX 整数指令仅适用于 32/64 位宽度。 AVX512BW 中至少添加了 16 位(尽管我仍然不明白为什么英特尔拒绝添加 8 位移位)。

我们可以通过使用 32 位变量移位和一些掩蔽和混合来仅使用 AVX2 来模拟 16 位变量移位。

我们需要包含每个 16 位元素的 32 位元素底部的右移计数,我们可以使用 AND(对于低位元素)和立即移位来实现高位半部分。 (与标量移位不同,x86 向量移位使它们的计数饱和而不是 wrapping/masking)。

我们还需要在进行高半移位之前屏蔽掉数据的低 16 位,这样我们就不会将垃圾移位到包含 32 位元素的高 16 位一半。

__m256i _mm256_sllv_epi16(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi32(0xffff0000);
    __m256i low_half = _mm256_sllv_epi32(
        a,
        _mm256_andnot_si256(mask, count)
    );
    __m256i high_half = _mm256_sllv_epi32(
        _mm256_and_si256(mask, a),
        _mm256_srli_epi32(count, 16)
    );
    return _mm256_blend_epi16(low_half, high_half, 0xaa);
}
__m256i _mm256_sllv_epi16(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi32(0xffff0000); // alternating low/high words of a dword
    // shift low word of each dword: low_half = (a << (count & 0xffff)) [for each 32b element]
    // note that, because `a` isn't being masked here, we may get some "junk" bits, but these will get eliminated by the blend below
    __m256i low_half = _mm256_sllv_epi32(
        a,
        _mm256_andnot_si256(mask, count)
    );
    // shift high word of each dword: high_half = ((a & 0xffff0000) << (count >> 16)) [for each 32b element]
    __m256i high_half = _mm256_sllv_epi32(
        _mm256_and_si256(mask, a),     // make sure we shift in zeros
        _mm256_srli_epi32(count, 16)   // need the high-16 count at the bottom of a 32-bit element
    );
    // combine low and high words
    return _mm256_blend_epi16(low_half, high_half, 0xaa);
}

__m256i _mm256_srlv_epi16(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi32(0x0000ffff);
    __m256i low_half = _mm256_srlv_epi32(
        _mm256_and_si256(mask, a),
        _mm256_and_si256(mask, count)
    );
    __m256i high_half = _mm256_srlv_epi32(
        a,
        _mm256_srli_epi32(count, 16)
    );
    return _mm256_blend_epi16(low_half, high_half, 0xaa);
}

GCC 8.2 将此编译为或多或少你所期望的:

_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
        vmovdqa       ymm3, YMMWORD PTR .LC0[rip]
        vpand   ymm2, ymm0, ymm3
        vpand   ymm3, ymm1, ymm3
        vpsrld  ymm1, ymm1, 16
        vpsrlvd ymm2, ymm2, ymm3
        vpsrlvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm2, ymm0, 170
        ret
_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
        vmovdqa       ymm3, YMMWORD PTR .LC1[rip]
        vpandn  ymm2, ymm3, ymm1
        vpsrld  ymm1, ymm1, 16
        vpsllvd ymm2, ymm0, ymm2
        vpand   ymm0, ymm0, ymm3
        vpsllvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm2, ymm0, 170
        ret

意味着仿真结果为 1x 加载 + 2x AND/ANDN + 2x 可变移位 + 1x 右移位 + 1x 混合。

Clang 6.0 做了一些有趣的事情 - 它通过使用混合消除了内存负载(和相应的屏蔽):

_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
        vpxor   xmm2, xmm2, xmm2
        vpblendw        ymm3, ymm1, ymm2, 170
        vpsllvd ymm3, ymm0, ymm3
        vpsrld  ymm1, ymm1, 16
        vpblendw        ymm0, ymm2, ymm0, 170
        vpsllvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm3, ymm0, 170
        ret
_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
        vpxor   xmm2, xmm2, xmm2
        vpblendw        ymm3, ymm0, ymm2, 170
        vpblendw        ymm2, ymm1, ymm2, 170
        vpsrlvd ymm2, ymm3, ymm2
        vpsrld  ymm1, ymm1, 16
        vpsrlvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm2, ymm0, 170
        ret

这导致:1x clear + 3x blend + 2x variable-shift + 1x right-shift。

我没有对哪种方法更快进行任何基准测试,但我怀疑这可能取决于 CPU,特别是 CPU 上的 PBLENDW 成本。

当然,如果您的用例受到更多限制,则可以简化上述内容,例如如果您的移位量都是常量,您可以删除使其工作所需的 masking/shifting(假设编译器不会自动为您执行此操作)。
对于左移,如果移位量是常数,则可以使用 _mm256_mullo_epi16 代替,将移位量转换为可以相乘的值,例如对于您给出的示例:

__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(1<<0,1<<1,1<<2,1<<3,1<<4,1<<5,1<<6,1<<7,1<<8,1<<9,1<<10,1<<11,1<<12,1<<13,1<<14,1<<15);
v1 = _mm256_mullo_epi16(v1, v2);

更新:Peter 提到(见下面的评论)右移也可以用 _mm256_mulhi_epi16 实现(例如执行 v>>1v 乘以 1<<15 和拿高的话)。


对于 8 位变量移位,这在 AVX512 中也不存在(同样,我不知道为什么英特尔没有 8 位 SIMD 移位)。
如果 AVX512BW 可用 ,您可以使用与上述类似的技巧,使用 _mm256_sllv_epi16对于 AVX2,我想不出比第二次应用 16 位仿真更好的方法,因为您最终必须将 32 位移位给您的移位做 4 倍。 请参阅@wim 的回答,了解 AVX2 中 8 位的良好解决方案。

这是我想出的(基本上是 16 位版本在 AVX512 上被 8 位采用):

__m256i _mm256_sllv_epi8(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi16(0xff00);
    __m256i low_half = _mm256_sllv_epi16(
        a,
        _mm256_andnot_si256(mask, count)
    );
    __m256i high_half = _mm256_sllv_epi16(
        _mm256_and_si256(mask, a),
        _mm256_srli_epi16(count, 8)
    );
    return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00));
}

__m256i _mm256_srlv_epi8(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi16(0x00ff);
    __m256i low_half = _mm256_srlv_epi16(
        _mm256_and_si256(mask, a),
        _mm256_and_si256(mask, count)
    );
    __m256i high_half = _mm256_srlv_epi16(
        a,
        _mm256_srli_epi16(count, 8)
    );
    return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00));
}

(Peter Cordes 在下面提到在纯 AVX512BW(+VL) 实现中 _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00)) 可以替换为 _mm256_mask_blend_epi8(0xaaaaaaaa, low_half, high_half),这可能更快)

_mm256_sllv_epi8 的情况下,使用 pshufb 指令作为微小的查找 table,用乘法代替移位并不难。也可以通过乘法和其他一些指令来模拟 _mm256_srlv_epi8 的右移,请参见下面的代码。我希望至少 _mm256_sllv_epi8 比 Nyan 的 .

更有效

或多或少可以使用相同的想法来模拟 _mm256_sllv_epi16,但在这种情况下,select 正确的乘数就不那么简单了(另请参见下面的代码)。

下面的解决方案 _mm256_sllv_epi16_emu 不一定比 Nyan 的 更快或更好。 性能取决于周围的代码和使用的 CPU。 尽管如此,这里的解决方案可能是有用的,至少在较旧的计算机系统上是这样。 例如,vpsllvd 指令在 Nyan 的解决方案中使用了两次。此指令在 Intel Skylake 系统或更新版本上速度很快。 在 Intel Broadwell 或 Haswell 上,这条指令很慢,因为它解码为 3 个微操作。这里的解决方案避免了这个慢指令。

如果已知移位计数小于或等于 15,则可以跳过带有 mask_lt_15 的两行代码。

缺失的内在 _mm256_srlv_epi16 留作 reader 的练习。


/*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell shift_v_epi8.c     */
#include <immintrin.h>
#include <stdio.h>
int print_epi8(__m256i  a);
int print_epi16(__m256i  a);

__m256i _mm256_sllv_epi8(__m256i a, __m256i count) {
    __m256i mask_hi        = _mm256_set1_epi32(0xFF00FF00);
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);

    __m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));     /* AVX shift counts are not masked. So a_i << n_i = 0 for n_i >= 8. count_sat is always less than 9.*/ 
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);  /* Select the right multiplication factor in the lookup table.                                      */
    __m256i x_lo           = _mm256_mullo_epi16(a, multiplier);               /* Unfortunately _mm256_mullo_epi8 doesn't exist. Split the 16 bit elements in a high and low part. */

    __m256i multiplier_hi  = _mm256_srli_epi16(multiplier, 8);                /* The multiplier of the high bits.                                                                 */
    __m256i a_hi           = _mm256_and_si256(a, mask_hi);                    /* Mask off the low bits.                                                                           */
    __m256i x_hi           = _mm256_mullo_epi16(a_hi, multiplier_hi);
    __m256i x              = _mm256_blendv_epi8(x_lo, x_hi, mask_hi);         /* Merge the high and low part.                                                                     */
            return x;
}


__m256i _mm256_srlv_epi8(__m256i a, __m256i count) {
    __m256i mask_hi        = _mm256_set1_epi32(0xFF00FF00);
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,128, 0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,128);

    __m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));     /* AVX shift counts are not masked. So a_i >> n_i = 0 for n_i >= 8. count_sat is always less than 9.*/ 
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);  /* Select the right multiplication factor in the lookup table.                                      */
    __m256i a_lo           = _mm256_andnot_si256(mask_hi, a);                 /* Mask off the high bits.                                                                          */
    __m256i multiplier_lo  = _mm256_andnot_si256(mask_hi, multiplier);        /* The multiplier of the low bits.                                                                  */
    __m256i x_lo           = _mm256_mullo_epi16(a_lo, multiplier_lo);         /* Shift left a_lo by multiplying.                                                                  */
            x_lo           = _mm256_srli_epi16(x_lo, 7);                      /* Shift right by 7 to get the low bits at the right position.                                      */

    __m256i multiplier_hi  = _mm256_and_si256(mask_hi, multiplier);           /* The multiplier of the high bits.                                                                 */
    __m256i x_hi           = _mm256_mulhi_epu16(a, multiplier_hi);            /* Variable shift left a_hi by multiplying. Use a instead of a_hi because the a_lo bits don't interfere */
            x_hi           = _mm256_slli_epi16(x_hi, 1);                      /* Shift left by 1 to get the high bits at the right position.                                      */
    __m256i x              = _mm256_blendv_epi8(x_lo, x_hi, mask_hi);         /* Merge the high and low part.                                                                     */
            return x;
}


__m256i _mm256_sllv_epi16_emu(__m256i a, __m256i count) {
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);
    __m256i byte_shuf_mask = _mm256_set_epi8(14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0, 14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0);

    __m256i mask_lt_15     = _mm256_cmpgt_epi16(_mm256_set1_epi16(16), count);
            a              = _mm256_and_si256(mask_lt_15, a);                    /* Set a to zero if count > 15.                                                                      */
            count          = _mm256_shuffle_epi8(count, byte_shuf_mask);         /* Duplicate bytes from the even postions to bytes at the even and odd positions.                    */
            count          = _mm256_sub_epi8(count,_mm256_set1_epi16(0x0800));   /* Subtract 8 at the even byte positions. Note that the vpshufb instruction selects a zero byte if the shuffle control mask is negative.     */
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count);         /* Select the right multiplication factor in the lookup table. Within the 16 bit elements, only the upper byte or the lower byte is nonzero. */
    __m256i x              = _mm256_mullo_epi16(a, multiplier);                  
            return x;
}


int main(){

    printf("Emulating _mm256_sllv_epi8:\n");
    __m256i a     = _mm256_set_epi8(32,31,30,29, 28,27,26,25, 24,23,22,21, 20,19,18,17, 16,15,14,13, 12,11,10,9, 8,7,6,5, 4,3,2,1);
    __m256i count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
    __m256i x     = _mm256_sllv_epi8(a, count);
    printf("a     = \n"); print_epi8(a    );
    printf("count = \n"); print_epi8(count);
    printf("x     = \n"); print_epi8(x    );
    printf("\n\n"); 


    printf("Emulating _mm256_srlv_epi8:\n");
            a     = _mm256_set_epi8(223,224,225,226, 227,228,229,230, 231,232,233,234, 235,236,237,238, 239,240,241,242, 243,244,245,246, 247,248,249,250, 251,252,253,254);
            count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
            x     = _mm256_srlv_epi8(a, count);
    printf("a     = \n"); print_epi8(a    );
    printf("count = \n"); print_epi8(count);
    printf("x     = \n"); print_epi8(x    );
    printf("\n\n"); 



    printf("Emulating _mm256_sllv_epi16:\n");
            a     = _mm256_set_epi16(1601,1501,1401,1301, 1200,1100,1000,900, 800,700,600,500, 400,300,200,100);
            count = _mm256_set_epi16(17,16,15,13,  11,10,9,8, 7,6,5,4, 3,2,1,0);
            x     = _mm256_sllv_epi16_emu(a, count);
    printf("a     = \n"); print_epi16(a    );
    printf("count = \n"); print_epi16(count);
    printf("x     = \n"); print_epi16(x    );
    printf("\n\n"); 

    return 0;
}


int print_epi8(__m256i  a){
  char v[32];
  int i;
  _mm256_storeu_si256((__m256i *)v,a);
  for (i = 0; i<32; i++) printf("%4hhu",v[i]);
  printf("\n");
  return 0;
}

int print_epi16(__m256i  a){
  unsigned short int  v[16];
  int i;
  _mm256_storeu_si256((__m256i *)v,a);
  for (i = 0; i<16; i++) printf("%6hu",v[i]);
  printf("\n");
  return 0;
}

输出为:

Emulating _mm256_sllv_epi8:
a     = 
   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32
count = 
   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
   1   4  12  32  80 192 192   0   0   0   0   0  13  28  60 128  16  64 192   0   0   0   0   0  25  52 108 224 208 192 192   0


Emulating _mm256_srlv_epi8:
a     = 
 254 253 252 251 250 249 248 247 246 245 244 243 242 241 240 239 238 237 236 235 234 233 232 231 230 229 228 227 226 225 224 223
count = 
   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
 254 126  63  31  15   7   3   1   0   0   0   0 242 120  60  29  14   7   3   1   0   0   0   0 230 114  57  28  14   7   3   1


Emulating _mm256_sllv_epi16:
a     = 
   100   200   300   400   500   600   700   800   900  1000  1100  1200  1301  1401  1501  1601
count = 
     0     1     2     3     4     5     6     7     8     9    10    11    13    15    16    17
x     = 
   100   400  1200  3200  8000 19200 44800 36864 33792 53248 12288 32768 40960 32768     0     0

确实缺少一些 AVX2 指令。 但是,请注意,通过模拟 'missing' AVX2 指令来填补这些空白并不总是一个好主意。有时是 以避免这些模拟指令的方式重新设计代码会更有效。例如,通过使用更宽的向量 元素(_epi32 而不是 _epi16),具有原生支持。