将字段中的位扩展到掩码中所有(重叠+相邻)设置位的最快方法?

Fastest way to expand bits in a field to all (overlapping + adjacent) set bits in a mask?

假设我有 2 个名为 IN 和 MASK 的二进制输入。实际字段大小可能是 32 到 256 位,具体取决于用于完成任务的指令集。两个输入每次调用都会改变。

Inputs:
IN   = ...1100010010010100...
MASK = ...0001111010111011...
Output:
OUT  = ...0001111010111000...

编辑:来自一些评论讨论的另一个例子

IN   = ...11111110011010110...
MASK = ...01011011001111110...
Output:
OUT  = ...01011011001111110...

我想获取 IN 的 1 位所在的 MASK 的连续相邻 1 位。 (这种操作有通用术语吗?也许我没有正确地表述我的搜索。)我正在尝试找到一种更快的方法来执行此操作。我愿意使用任何 x86 或 x86 SIMD 扩展,这些扩展可以在最少 cpu 周期内完成。更广泛的数据类型 SIMD 是首选,因为它允许我一次处理更多数据。

我想出的最好的朴素解决方案是以下伪代码,它手动左移直到没有更多匹配位,然后重复右移:

// (using the variables above)
testL = testR = OUT = (IN & MASK);

LoopL:
testL = (testL << 1) & MASK;
if (testL != 0) {
    OUT = OUT | testL;
    goto LoopL;
}

LoopR:
testR = (testR >> 1) & MASK;
if (testR != 0) {
    OUT = OUT | testR;
    goto LoopR;
}

return OUT;

下面的方法只需要一个循环,迭代次数等于 'groups' 找到的次数。 我不知道它是否会比你的方法更有效率;每次迭代有 6 arith/bitwise 次操作。

在伪代码中(C-like):

OUT = 0;
a = MASK;
while (a)
{
    e = a & ~(a + (a & (-a)));
    if (e & IN) OUT |= e;
    a ^= e;
}

这是它的工作原理,一步一步,使用 11010111 作为示例掩码:

OUT = 0

a = MASK        11010111
c = a & (-a)    00000001   keeps rightmost one only
d = a + c       11011000   clears rightmost group (and set the bit to its immediate left)
e = a & ~d      00000111   keeps rightmost group only

if (e & IN) OUT |= e;      adds group to OUT

a = a ^ e       11010000   clears rightmost group, so we can proceed with the next group
c = a & (-a)    00010000
d = a + c       11100000
e = a & ~d      00010000

if (e & IN) OUT |= e;

a = a ^ e       11000000
c = a & (-a)    01000000
d = a + c       00000000   (ignoring carry when adding)
e = a & ~d      11000000

if (e & IN) OUT |= e;

a = a ^ e       00000000   done

正如@PeterCordes 所指出的,一些操作可以使用 x86 BMI1 指令进行优化:

这种方法适用于不支持按位反转的处理器架构。在确实有专用指令来反转整数中位顺序的体系结构上, 效率更高。

我猜@fuz 的评论是对的。 以下示例显示了 SSE 和 AVX2 代码的工作原理。 算法以IN_reduced = IN & MASK开头,因为我们不感兴趣 在 IN 位中 MASK0 的位置。

IN                                  = . . . 0 0 0 0 . . . . p q r s . . .
MASK                                = . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 . . 
IN_reduced = IN & MASK              = . . 0 0 0 0 0 0 . . 0 p q r s 0 . .

如果任何 p q r s 位是 1,则 IN_reduced + MASK 有一个进位位 1X 位置,它位于 请求的连续位。

MASK                                = . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 . . 
IN_reduced                          = . . 0 0 0 0 0 0 . . 0 p q r s 0 . .
IN_reduced + MASK                   = . . 0 1 1 1 1 . . . 1 . . . . . .
                                                          X
(IN_reduced + MASK) >>1             = . . . 0 1 1 1 1 . . . 1 . . . . . .

对于 >> 1,此进位位 1 被移至与位 p 相同的列 (连续位的第一位)。 现在,(IN_reduced + MASK) >>1 实际上是 IN_reducedMASK 的平均值。 为了避免可能的加法溢出,我们使用以下 平均值:avg(a, b) = (a & b) + ((a ^ b) >> 1)(参见@Harold 的评论, 另见 here and here。) 使用 average = avg(IN_reduced, MASK) 我们得到

MASK                                = . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 . . 
IN_reduced                          = . . 0 0 0 0 0 0 . . 0 p q r s 0 . .
average                             = . . . 0 1 1 1 1 . . . 1 . . . . . .
MASK >> 1                           = . . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 .  
leading_bits = (~(MASK>>1))&average = . . . 0 0 0 0 0 . . . 1 0 0 0 0 . .  

我们可以用 leading_bits = (~(MASK>>1) ) & average 因为 MASK>>1 在这些位置为零 的进位 我们感兴趣的。

对于正常加法,进位从右向左传播。这里我们使用一个 反向加法:从左到右有一个进位。 反向添加 MASKleading_bitsrev_added = bit_swap(bit_swap(MASK) + bit_swap(leading_bits)), 这会将位归零 想要的职位。 使用 OUT = (~rev_added) & MASK 我们得到结果。

MASK                                = . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 . . 
leading_bits                        = . . . 0 0 0 0 0 . . . 1 0 0 0 0 . .  
rev_added (MASK,leading_bits)       = . . . 1 1 1 1 0 . . . 0 0 0 0 1 . .
OUT = ~rev_added & MASK             = . . 0 0 0 0 0 0 . . . 1 1 1 1 0 . .

算法没有经过彻底测试,但输出看起来不错。


下面的代码块包含两个单独的代码: 上半部分是SSE代码, 下半部分是AVX2代码。 (为了避免 用两个大代码块使答案膨胀太多。) SSE 算法适用于 2 个 64 位元素,而 AVX2 版本适用于 4 个 64 位元素。

使用 gcc 9.1,算法 compiles to about 29 instructions, 除了 4 vmovdqa-s 用于加载一些常量,这很可能 在现实世界的应用程序中被提升到循环之外(内联之后)。 这 29 条指令很好地混合了执行的 9 次随机播放 (vpshufb) 在 Intel Skylake 的端口 5 (p5) 上,以及许多其他经常可能出现的指令 在 p0、p1 或 p5 上执行。

因此,每个周期执行大约 3 条指令是可能的。 在那种情况下,吞吐量大约是 1 个函数调用(内联) 每 10 个周期。在 AVX2 的情况下,这意味着每个 4 uint64_t OUT 个结果 大约 10 个周期。

请注意,性能与数据无关(!),这是一个很好的 我认为这个答案的好处。该解决方案是无分支的,无环的,并且 不能遭受失败的分支预测。


/*  gcc -O3 -m64 -Wall -march=skylake select_bits.c    */
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

int print_sse_128_bin(__m128i x);
__m128i bit_128_k(unsigned int k);
__m128i mm_bitreverse_epi64(__m128i x);
__m128i mm_revadd_epi64(__m128i x, __m128i y);


/* Select specific pieces of contiguous bits from `MASK` based on selector `IN`  */
__m128i mm_select_bits_epi64(__m128i IN, __m128i MASK){
    __m128i IN_reduced   = _mm_and_si128(IN, MASK);
    /* Compute the average of IN_reduced and MASK with avg(a,b)=(a&b)+((a^b)>>1)  */
    /* (IN_reduced & MASK) + ((IN_reduced ^ MASK) >>1) =                          */
    /* ((IN & MASK) & MASK) + ((IN_reduced ^ MASK) >>1) =                         */
    /* IN_reduced + ((IN_reduced ^ MASK) >>1)                                     */
    __m128i tmp          = _mm_xor_si128(IN_reduced, MASK);
    __m128i tmp_div2     = _mm_srli_epi64(tmp, 1);
    __m128i average      = _mm_add_epi64(IN_reduced, tmp_div2);   /* average is the average */
    __m128i MASK_div2    = _mm_srli_epi64(MASK, 1);
    __m128i leading_bits = _mm_andnot_si128(MASK_div2, average);
    __m128i rev_added    = mm_revadd_epi64(MASK, leading_bits);
    __m128i OUT          = _mm_andnot_si128(rev_added, MASK);
    /* Uncomment the next lines to check the arithmetic */ /*   
    printf("IN           ");print_sse_128_bin(IN           );       
    printf("MASK         ");print_sse_128_bin(MASK         ); 
    printf("IN_reduced   ");print_sse_128_bin(IN_reduced   );       
    printf("tmp          ");print_sse_128_bin(tmp          );       
    printf("tmp_div2     ");print_sse_128_bin(tmp_div2     );       
    printf("average      ");print_sse_128_bin(average      );       
    printf("MASK_div2    ");print_sse_128_bin(MASK_div2    );       
    printf("leading_bits ");print_sse_128_bin(leading_bits );       
    printf("rev_added    ");print_sse_128_bin(rev_added    );       
    printf("OUT          ");print_sse_128_bin(OUT          );       
    printf("\n");*/
    return OUT;       
}


int main(){
    __m128i IN   = _mm_set_epi64x(0b11111110011010110, 0b1100010010010100);
    __m128i MASK = _mm_set_epi64x(0b01011011001111110, 0b0001111010111011);
    __m128i OUT;    

    printf("Example 1 \n");
    OUT = mm_select_bits_epi64(IN, MASK);
    printf("IN           ");print_sse_128_bin(IN);
    printf("MASK         ");print_sse_128_bin(MASK);
    printf("OUT          ");print_sse_128_bin(OUT);
    printf("\n\n");

                      /*  0b7654321076543210765432107654321076543210765432107654321076543210  */
    IN   = _mm_set_epi64x(0b1000001001001010000010000000100000010000000000100000000111100011, 
                          0b11111110011010111);
    MASK = _mm_set_epi64x(0b1110011110101110111111000000000111011111101101111100011111000001, 
                          0b01011011001111111);

    printf("Example 2 \n");
    OUT = mm_select_bits_epi64(IN, MASK);
    printf("IN           ");print_sse_128_bin(IN);
    printf("MASK         ");print_sse_128_bin(MASK);
    printf("OUT          ");print_sse_128_bin(OUT);
    printf("\n\n");

    return 0;
}


int print_sse_128_bin(__m128i x){
    for (int i = 127; i >= 0; i--){
        printf("%1u", _mm_testnzc_si128(bit_128_k(i), x));
        if (((i & 7) == 0) && (i > 0)) printf(" ");
    }
    printf("\n");
    return 0;
}


/* From my answer here  adapted to 128-bit */
inline __m128i bit_128_k(unsigned int k){
  __m128i  indices     = _mm_set_epi32(96, 64, 32, 0);
  __m128i  one         = _mm_set1_epi32(1);

  __m128i  kvec        = _mm_set1_epi32(k);  
  __m128i  shiftcounts = _mm_sub_epi32(kvec, indices);
  __m128i  kbit        = _mm_sllv_epi32(one, shiftcounts);   
  return kbit;                             
}


/* Copied from Harold's answer          */
/* Adapted to epi64 and __m128i: bit reverse two 64 bit elements                    */
inline __m128i mm_bitreverse_epi64(__m128i x){
    __m128i shufbytes = _mm_setr_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8); 
    __m128i luthigh = _mm_setr_epi8(0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15);
    __m128i lutlow = _mm_slli_epi16(luthigh, 4);
    __m128i lowmask = _mm_set1_epi8(15);
    __m128i rbytes = _mm_shuffle_epi8(x, shufbytes);
    __m128i high = _mm_shuffle_epi8(lutlow, _mm_and_si128(rbytes, lowmask));
    __m128i low = _mm_shuffle_epi8(luthigh, _mm_and_si128(_mm_srli_epi16(rbytes, 4), lowmask));
    return _mm_or_si128(low, high);
}


/* Add in the reverse direction: With a carry from left to */
/* right, instead of right to left                         */
inline __m128i mm_revadd_epi64(__m128i x, __m128i y){
    x = mm_bitreverse_epi64(x);
    y = mm_bitreverse_epi64(y);
    __m128i sum = _mm_add_epi64(x, y);
    return mm_bitreverse_epi64(sum);
}
/* End of SSE code */


/************* AVX2 code starts here ********************************************/

/*  gcc -O3 -m64 -Wall -march=skylake select_bits256.c    */
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

int print_avx_256_bin(__m256i x);
__m256i bit_256_k(unsigned int k);
__m256i mm256_bitreverse_epi64(__m256i x);
__m256i mm256_revadd_epi64(__m256i x, __m256i y);


/* Select specific pieces of contiguous bits from `MASK` based on selector `IN`  */
__m256i mm256_select_bits_epi64(__m256i IN, __m256i MASK){
    __m256i IN_reduced   = _mm256_and_si256(IN, MASK);
    /* Compute the average of IN_reduced and MASK with avg(a,b)=(a&b)+((a^b)>>1)  */
    /* (IN_reduced & MASK) + ((IN_reduced ^ MASK) >>1) =                          */
    /* ((IN & MASK) & MASK) + ((IN_reduced ^ MASK) >>1) =                         */
    /* IN_reduced + ((IN_reduced ^ MASK) >>1)                                     */
    __m256i tmp          = _mm256_xor_si256(IN_reduced, MASK);
    __m256i tmp_div2     = _mm256_srli_epi64(tmp, 1);
    __m256i average      = _mm256_add_epi64(IN_reduced, tmp_div2);   /* average is the average */
    __m256i MASK_div2    = _mm256_srli_epi64(MASK, 1);
    __m256i leading_bits = _mm256_andnot_si256(MASK_div2, average);
    __m256i rev_added    = mm256_revadd_epi64(MASK, leading_bits);
    __m256i OUT          = _mm256_andnot_si256(rev_added, MASK);
    /* Uncomment the next lines to check the arithmetic */ /*   
    printf("IN           ");print_avx_256_bin(IN           );       
    printf("MASK         ");print_avx_256_bin(MASK         ); 
    printf("IN_reduced   ");print_avx_256_bin(IN_reduced   );       
    printf("tmp          ");print_avx_256_bin(tmp          );       
    printf("tmp_div2     ");print_avx_256_bin(tmp_div2     );       
    printf("average      ");print_avx_256_bin(average      );       
    printf("MASK_div2    ");print_avx_256_bin(MASK_div2    );       
    printf("leading_bits ");print_avx_256_bin(leading_bits );       
    printf("rev_added    ");print_avx_256_bin(rev_added    );       
    printf("OUT          ");print_avx_256_bin(OUT          );       
    printf("\n");*/
    return OUT;       
}


int main(){
    __m256i IN   = _mm256_set_epi64x(0b11111110011010110, 
                                     0b1100010010010100,
                                     0b1000001001001010000010000000100000010000000000100000000111100011, 
                                     0b11111110011010111
    );
    __m256i MASK = _mm256_set_epi64x(0b01011011001111110, 
                                     0b0001111010111011,
                                     0b1110011110101110111111000000000111011111101101111100011111000001, 
                                     0b01011011001111111);
    __m256i OUT;    

    printf("Example \n");
    OUT = mm256_select_bits_epi64(IN, MASK);
    printf("IN           ");print_avx_256_bin(IN);
    printf("MASK         ");print_avx_256_bin(MASK);
    printf("OUT          ");print_avx_256_bin(OUT);
    printf("\n");

    return 0;
}


int print_avx_256_bin(__m256i x){
    for (int i=255;i>=0;i--){
        printf("%1u",_mm256_testnzc_si256(bit_256_k(i),x));
        if (((i&7) ==0)&&(i>0)) printf(" ");
    }
    printf("\n");
    return 0;
}


/* From my answer here  */
inline __m256i bit_256_k(unsigned int k){
  __m256i  indices     = _mm256_set_epi32(224,192,160,128,96,64,32,0);
  __m256i  one         = _mm256_set1_epi32(1);

  __m256i  kvec        = _mm256_set1_epi32(k);  
  __m256i  shiftcounts = _mm256_sub_epi32(kvec, indices);
  __m256i  kbit        = _mm256_sllv_epi32(one, shiftcounts);   
  return kbit;                             
}


/* Copied from Harold's answer          */
/* Adapted to epi64: bit reverse four 64 bit elements                    */
inline __m256i mm256_bitreverse_epi64(__m256i x){
    __m256i shufbytes = _mm256_setr_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8);
    __m256i luthigh = _mm256_setr_epi8(0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15);
    __m256i lutlow = _mm256_slli_epi16(luthigh, 4);
    __m256i lowmask = _mm256_set1_epi8(15);
    __m256i rbytes = _mm256_shuffle_epi8(x, shufbytes);
    __m256i high = _mm256_shuffle_epi8(lutlow, _mm256_and_si256(rbytes, lowmask));
    __m256i low = _mm256_shuffle_epi8(luthigh, _mm256_and_si256(_mm256_srli_epi16(rbytes, 4), lowmask));
    return _mm256_or_si256(low, high);
}


/* Add in the reverse direction: With a carry from left to */
/* right, instead of right to left                         */
inline __m256i mm256_revadd_epi64(__m256i x, __m256i y){
    x = mm256_bitreverse_epi64(x);
    y = mm256_bitreverse_epi64(y);
    __m256i sum = _mm256_add_epi64(x, y);
    return mm256_bitreverse_epi64(sum);
}


带有未注释调试部分的 SSE 代码输出:

Example 1 
IN           00000000 00000000 00000000 00000000 00000000 00000001 11111100 11010110 00000000 00000000 00000000 00000000 00000000 00000000 11000100 10010100
MASK         00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111110 00000000 00000000 00000000 00000000 00000000 00000000 00011110 10111011
IN_reduced   00000000 00000000 00000000 00000000 00000000 00000000 10110100 01010110 00000000 00000000 00000000 00000000 00000000 00000000 00000100 10010000
tmp          00000000 00000000 00000000 00000000 00000000 00000000 00000010 00101000 00000000 00000000 00000000 00000000 00000000 00000000 00011010 00101011
tmp_div2     00000000 00000000 00000000 00000000 00000000 00000000 00000001 00010100 00000000 00000000 00000000 00000000 00000000 00000000 00001101 00010101
average      00000000 00000000 00000000 00000000 00000000 00000000 10110101 01101010 00000000 00000000 00000000 00000000 00000000 00000000 00010001 10100101
MASK_div2    00000000 00000000 00000000 00000000 00000000 00000000 01011011 00111111 00000000 00000000 00000000 00000000 00000000 00000000 00001111 01011101
leading_bits 00000000 00000000 00000000 00000000 00000000 00000000 10100100 01000000 00000000 00000000 00000000 00000000 00000000 00000000 00010000 10100000
rev_added    00000000 00000000 00000000 00000000 00000000 00000000 01001001 00000001 00000000 00000000 00000000 00000000 00000000 00000000 00000001 01000111
OUT          00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111110 00000000 00000000 00000000 00000000 00000000 00000000 00011110 10111000

IN           00000000 00000000 00000000 00000000 00000000 00000001 11111100 11010110 00000000 00000000 00000000 00000000 00000000 00000000 11000100 10010100
MASK         00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111110 00000000 00000000 00000000 00000000 00000000 00000000 00011110 10111011
OUT          00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111110 00000000 00000000 00000000 00000000 00000000 00000000 00011110 10111000


Example 2 
IN           10000010 01001010 00001000 00001000 00010000 00000010 00000001 11100011 00000000 00000000 00000000 00000000 00000000 00000001 11111100 11010111
MASK         11100111 10101110 11111100 00000001 11011111 10110111 11000111 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111111
IN_reduced   10000010 00001010 00001000 00000000 00010000 00000010 00000001 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110100 01010111
tmp          01100101 10100100 11110100 00000001 11001111 10110101 11000110 00000000 00000000 00000000 00000000 00000000 00000000 00000000 00000010 00101000
tmp_div2     00110010 11010010 01111010 00000000 11100111 11011010 11100011 00000000 00000000 00000000 00000000 00000000 00000000 00000000 00000001 00010100
average      10110100 11011100 10000010 00000000 11110111 11011100 11100100 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110101 01101011
MASK_div2    01110011 11010111 01111110 00000000 11101111 11011011 11100011 11100000 00000000 00000000 00000000 00000000 00000000 00000000 01011011 00111111
leading_bits 10000100 00001000 10000000 00000000 00010000 00000100 00000100 00000001 00000000 00000000 00000000 00000000 00000000 00000000 10100100 01000000
rev_added    00010000 01100001 00000010 00000001 11000000 01110000 00100000 00100000 00000000 00000000 00000000 00000000 00000000 00000000 01001001 00000000
OUT          11100111 10001110 11111100 00000000 00011111 10000111 11000111 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111111

IN           10000010 01001010 00001000 00001000 00010000 00000010 00000001 11100011 00000000 00000000 00000000 00000000 00000000 00000001 11111100 11010111
MASK         11100111 10101110 11111100 00000001 11011111 10110111 11000111 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111111
OUT          11100111 10001110 11111100 00000000 00011111 10000111 11000111 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111111