AVX2:AVX 寄存器中 8 位元素上的 BitScanReverse 或 CountLeadingZeros

AVX2: BitScanReverse or CountLeadingZeros on 8 bit elements in AVX register

我想提取具有 8 位元素的 256 位 AVX 寄存器中最高设置位的索引。我既找不到 bsr 也找不到 clz 实现。

对于具有 32 位元素的 clz,有带浮点数转换的 bithack,但这对于 8 位元素可能是不可能的。

目前,我正在研究一个解决方案,我会逐位检查,稍后我会添加,但我想知道是否有更快的方法来做到这一点。

AVX512 解决方案,没试过,但我认为这个想法应该可行:

// Form four 32-bit vectors with high bytes from the source
__m256i a0 = _mm256_or_si256(_mm256_slli_si256(a, 3),  _mm256_set1_epi32(0x00FF'FFFF));
__m256i a1 = _mm256_or_si256(_mm256_slli_si256(a, 2),  _mm256_set1_epi32(0x00FF'FFFF));
__m256i a2 = _mm256_or_si256(_mm256_slli_si256(a, 1),  _mm256_set1_epi32(0x00FF'FFFF));
__m256i a3 = _mm256_or_si256(                  a,      _mm256_set1_epi32(0x00FF'FFFF));
// Count lead bits and shift according to bit position
__m256i c0 =                   _mm256_lzcnt_epi32(a0);
__m256i c1 = _mm256_slli_si256(_mm256_lzcnt_epi32(a1), 1);
__m256i c2 = _mm256_slli_si256(_mm256_lzcnt_epi32(a2), 2);
__m256i c3 = _mm256_slli_si256(_mm256_lzcnt_epi32(a3), 3);
//Gather the result
__m256i r  = _mm256_or_si256(_mm256_or_si256(c0,c1),_mm256_or_si256(c2,c3));

不确定是否比一个一个检查快

给定目标 AVX 寄存器 _a,这有效。如果有什么需要优化的,请告诉我(或直接编辑)。

__m256i _a;
__m256i _old_mask = _mm256_set1_epi8(-1);
__m256i _extract_bitmask, _extracted_bit, _mask;

for (int i = 7; i >= 0; i--)
{
    // bitmask to extract bit from _a at position i
    _extract_bitmask = _mm256_set1_epi8(1 << i);

    // the extracted bit
    _extracted_bit = _mm256_and_si256(_a, _extract_bitmask);
    
    // check if bit at position i is set and if was not set before
    _mask = _mm256_cmpeq_epi8(_extract_bitmask, _extracted_bit);
    _mask = _mm256_and_si256(_mask, _old_mask);
    
    // update mask
    _old_mask = _mm256_andnot_si256(_mask, _old_mask);

    // update result according to _mask
    _result = _mm256_blendv_epi8(_result, _mm256_set1_epi8(i), _mask);
}

这是一个基于 vpshufb 的解决方案。这个想法是将输入分成两半,对两半进行查找并合并结果:

__m256i clz_epu8(__m256i values)
{
    // extract upper nibble:
    __m256i hi = _mm256_and_si256(_mm256_srli_epi16(values, 4), _mm256_set1_epi8(0xf));
    // this sets the highest bit for values >= 0x10 and otherwise keeps the lower nibble unmodified:
    __m256i lo = _mm256_adds_epu8(values, _mm256_set1_epi8(0x70));

    // lookup tables for count-leading-zeros (replace this by _mm256_setr_epi8, if this does not get optimized away)
    // ideally, this should compile to vbroadcastf128 ...
    const __m256i lookup_hi = _mm256_broadcastsi128_si256(_mm_setr_epi8(0, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0));
    const __m256i lookup_lo = _mm256_broadcastsi128_si256(_mm_setr_epi8(8, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4));

    // look up each half
    __m256i clz_hi = _mm256_shuffle_epi8(lookup_hi, hi);
    __m256i clz_lo = _mm256_shuffle_epi8(lookup_lo, lo);

    // combine results (addition or xor would work as well)
    return _mm256_or_si256(clz_hi, clz_lo);
}

godbolt-link 粗略测试:https://godbolt.org/z/MYq74Wxdh

通常 _mm_shuffle_epi8 需要屏蔽来隔离每个半字节以将其用作 LUT,因为设置高位会使输出元素为 0。但是对于 CLZ,如果设置了高位,则结果正确是因为整个字节是 0,我们组合的方式意味着 lut_lo 可以生成它。

__m128i ssse3_lzcnt_epi8(__m128i v) {
    const __m128i lut_lo = _mm_set_epi8(4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 8);
    const __m128i lut_hi = _mm_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 8);
    __m128i t;

    t = _mm_and_si128(_mm_srli_epi16(v, 4), _mm_set1_epi8(0x0F));
    t = _mm_shuffle_epi8(lut_hi, t);
    v = _mm_shuffle_epi8(lut_lo, v);
    v = _mm_min_epu8(v, t);
    return v;
}

与使用 _mm_adds_epu8 并将 LUT 结果与 or 组合相比,这节省了一条指令。