计算 AVX2 向量中每个元素的前导零位,模拟 _mm256_lzcnt_epi32

Count leading zero bits for each element in AVX2 vector, emulate _mm256_lzcnt_epi32

对于 AVX512,有内在的 _mm256_lzcnt_epi32,其中 returns 一个向量,对于 8 个 32 位元素中的每一个,包含输入向量元素中前导零位的数量.

是否有仅使用 AVX 和 AVX2 指令实现此功能的有效方法?

目前我正在使用一个循环来提取每个元素并应用 _lzcnt_u32 函数。


相关:要对一个大位图进行位扫描,请参阅 ,它使用 pmovmskb -> 位扫描来查找要对哪个字节进行标量位扫描。

这个问题是关于在您实际要使用所有 8 个结果时对 8 个单独的 32 位元素执行 8 个单独的 lzcnts,而不仅仅是 select 一个。

@aqrit 的回答看起来更像是对 FP bithacks 的巧妙使用。我下面的回答是基于我首先寻找一个旧的并且针对标量的 bithack,所以它没有试图避免 double (它比 int32 宽,因此是 SIMD 的问题).

它使用硬件签名 int->float 转换和饱和整数减法来处理设置的 MSB(负浮点数),而不是将位填充到尾数中以用于手动 uint->double .如果您可以将 MXCSR 设置为在其中的很多 _mm256_lzcnt_epi32 中向下取整,效率会更高。


https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogIEEE64Float suggests stuffing integers into the mantissa of a large double, then subtracting to get the FPU hardware to get a normalized double. (I think this bit of magic is doing uint32_t -> double, with the technique @Mysticial explains in (适用于 uint64_t 最多 252-1)

然后获取 double 的指数位并取消偏差。

我认为整数 log2 与 lzcnt 是一样的,但在 2 的幂处可能存在 off-by-1。

Standford Graphics bithack 页面列出了您可以使用的其他无分支 bithack,它们可能仍然优于 8x 标量 lzcnt

如果您知道您的数字总是很小(例如小于 2^23),您可以使用 float 来做到这一点并避免拆分和混合。

  int v; // 32-bit integer to find the log base 2 of
  int r; // result of log_2(v) goes here
  union { unsigned int u[2]; double d; } t; // temp

  t.u[__FLOAT_WORD_ORDER==LITTLE_ENDIAN] = 0x43300000;
  t.u[__FLOAT_WORD_ORDER!=LITTLE_ENDIAN] = v;
  t.d -= 4503599627370496.0;
  r = (t.u[__FLOAT_WORD_ORDER==LITTLE_ENDIAN] >> 20) - 0x3FF;

The code above loads a 64-bit (IEEE-754 floating-point) double with a 32-bit integer (with no paddding bits) by storing the integer in the mantissa while the exponent is set to 252. From this newly minted double, 252 (expressed as a double) is subtracted, which sets the resulting exponent to the log base 2 of the input value, v. All that is left is shifting the exponent bits into position (20 bits right) and subtracting the bias, 0x3FF (which is 1023 decimal).

要使用 AVX2 执行此操作,请将 odd/even 与 set1_epi32(0x43300000)_mm256_castps_pd 混合并混合,以获得 __m256d 减去 _mm256_castpd_si256 并将 low/high 两半移动/混合到位,然后掩码以获得指数。

使用 AVX2 对 FP 位模式进行整数运算非常有效,在对 FP 数学指令的输出进行整数移位时,旁路延迟只需 1 个额外延迟周期。

(TODO:用 C++ 内在函数编写它,欢迎编辑,或者其他人可以 post 它作为答案。)


我不确定您是否可以通过 int -> double conversion 然后读取指数字段来做任何事情。负数没有前导零,正数给出一个取决于大小的指数。

如果您确实想要那样,您将一次进入一个 128 位通道,洗牌以提供 xmm -> ymm packed int32_t -> packed double 转换。

float 表示指数格式的数字,因此 int->FP 转换为我们提供了指数字段中编码的最高设置位的位置。

我们希望 int->float 的幅度四舍五入 down(将值截断为 0),而不是默认的最近舍入。这可能会使 0x3FFFFFFF 看起来像 0x40000000。如果您在不进行任何 FP 数学运算的情况下进行大量此类转换,则可以将 MXCSR1 中的舍入模式设置为截断,然后在完成后将其设置回去。

否则,您可以使用 v & ~(v>>8) 保留 8 个最高有效位,并将一些或所有较低位清零,包括 MSB 下面可能设置的第 8 位。这足以确保所有舍入模式永远不会舍入到下一个 2 的幂。它始终保留 8 个 MSB,因为 v>>8 移入 8 个零,因此倒置为 8 个。在较低的位位置,无论 MSB 在哪里,8 个零都会从较高的位置移过那里,因此它永远不会清除任何整数的最高有效位。根据 MSB 下方的设置位排列方式,它可能会或可能不会清除更多低于 8 个最重要的位。

转换后,我们对位模式使用整数移位,将指数(和符号位)移至底部,并使用饱和减法消除偏差。如果原始 32 位输入中没有设置任何位,我们使用 min 将结果设置为 32。

__m256i avx2_lzcnt_epi32 (__m256i v) {
    // prevent value from being rounded up to the next power of two
    v = _mm256_andnot_si256(_mm256_srli_epi32(v, 8), v); // keep 8 MSB

    v = _mm256_castps_si256(_mm256_cvtepi32_ps(v)); // convert an integer to float
    v = _mm256_srli_epi32(v, 23); // shift down the exponent
    v = _mm256_subs_epu16(_mm256_set1_epi32(158), v); // undo bias
    v = _mm256_min_epi16(v, _mm256_set1_epi32(32)); // clamp at 32

    return v;
}

脚注 1:fp->int 转换可用于截断 (cvtt),但 int->fp 转换仅可用于默认舍入(受 MXCSR 约束)。

AVX512F 引入了 512 位向量的舍入模式覆盖,这将解决问题,__m512 _mm512_cvt_roundepi32_ps( __m512i a, int r);。但是所有带 AVX512F 的 CPU 也支持 AVX512CD,所以你可以只使用 _mm512_lzcnt_epi32。对于 AVX512VL,_mm256_lzcnt_epi32

问题也被标记为AVX,但是AVX中没有整数处理的说明,这意味着需要在支持AVX的平台上回退到SSE但是不是 AVX2。我在下面展示了一个经过详尽测试但有点行人的版本。这里的基本思想与其他答案一样,前导零的计数由整数到浮点转换期间发生的浮点归一化确定。结果的指数与前导零的计数一一对应,除了在参数为零的情况下结果是错误的。概念上:

clz (a) = (158 - (float_as_uint32 (uint32_to_float_rz (a)) >> 23)) + (a == 0)

其中 float_as_uint32() 是重新解释转换,uint32_to_float_rz() 是从无符号整数到浮点数的转换 并截断 。正常的舍入转换可能会将转换结果提高到下一个 2 的幂,从而导致前导零位的计数不正确。

SSE 不提供截断整数到浮点数的转换作为单个指令,也不提供无符号整数的转换。需要模拟此功能。仿真不需要精确,只要它不改变转换结果的幅度即可。截断部分由来自 反转 - 右移 - 和 n 技术处理。要使用有符号转换,我们在转换前将数字减半,然后在转换后加倍并递增:

float approximate_uint32_to_float_rz (uint32_t a)
{
    float r = (float)(int)((a >> 1) & ~(a >> 2));
    return r + r + 1.0f;
}

这种方法在下面的 sse_clz() 中被翻译成 SSE 内在函数。

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include "immintrin.h"

/* compute count of leading zero bits using floating-point normalization.

   clz(a) = (158 - (float_as_uint32 (uint32_to_float_rz (a)) >> 23)) + (a == 0)

   The problematic part here is uint32_to_float_rz(). SSE does not offer
   conversion of unsigned integers, and no rounding modes in integer to
   floating-point conversion. Since all we need is an approximate version
   that preserves order of magnitude:

   float approximate_uint32_to_float_rz (uint32_t a)
   {
      float r = (float)(int)((a >> 1) & ~(a >> 2));
      return r + r + 1.0f;
   }
*/  
__m128i sse_clz (__m128i a) 
{
    __m128 fp1 = _mm_set_ps1 (1.0f);
    __m128i zero = _mm_set1_epi32 (0);
    __m128i i158 = _mm_set1_epi32 (158);
    __m128i iszero = _mm_cmpeq_epi32 (a, zero);
    __m128i lsr1 = _mm_srli_epi32 (a, 1);
    __m128i lsr2 = _mm_srli_epi32 (a, 2);
    __m128i atrunc = _mm_andnot_si128 (lsr2, lsr1);
    __m128 atruncf = _mm_cvtepi32_ps (atrunc);
    __m128 atruncf2 = _mm_add_ps (atruncf, atruncf);
    __m128 conv = _mm_add_ps (atruncf2, fp1);
    __m128i convi = _mm_castps_si128 (conv);
    __m128i lsr23 = _mm_srli_epi32 (convi, 23);
    __m128i res = _mm_sub_epi32 (i158, lsr23);
    return _mm_sub_epi32 (res, iszero);
}

/* Portable reference implementation of 32-bit count of leading zeros */    
int clz32 (uint32_t a)
{
    uint32_t r = 32;
    if (a >= 0x00010000) { a >>= 16; r -= 16; }
    if (a >= 0x00000100) { a >>=  8; r -=  8; }
    if (a >= 0x00000010) { a >>=  4; r -=  4; }
    if (a >= 0x00000004) { a >>=  2; r -=  2; }
    r -= a - (a & (a >> 1));
    return r;
}

/* Test floating-point based count leading zeros exhaustively */
int main (void)
{
    __m128i res;
    uint32_t resi[4], refi[4];
    uint32_t count = 0;
    do {
        refi[0] = clz32 (count);
        refi[1] = clz32 (count + 1);
        refi[2] = clz32 (count + 2);
        refi[3] = clz32 (count + 3);
        res = sse_clz (_mm_set_epi32 (count + 3, count + 2, count + 1, count));
        memcpy (resi, &res, sizeof resi);
        if ((resi[0] != refi[0]) || (resi[1] != refi[1]) ||
            (resi[2] != refi[2]) || (resi[3] != refi[3])) {
            printf ("error @ %08x %08x %08x %08x\n",
                    count, count+1, count+2, count+3);
            return EXIT_FAILURE;
        }
        count += 4;
    } while (count);
    return EXIT_SUCCESS;
}