我可以使用 AVX FMA 单元进行位精确的 52 位整数乘法吗?

Can I use the AVX FMA units to do bit-exact 52 bit integer multiplications?

AXV2 没有任何大于 32 位的源的整数乘法。它确实提供 32 x 32 -> 32 multiplies, as well as 32 x 32 -> 64 乘法 1,但没有 64 位源。

假设我需要一个输入大于 32 位但小于或等于 52 位的无符号乘法 - 我可以简单地使用浮点 DP multiply 或 FMA 指令吗,输出是当整数输入和结果可以用 52 位或更少的位表示时(即在 [0, 2^52-1] 范围内)?

我想要产品的所有 104 位的更一般情况如何?或者整数乘积占用超过 52 位的情况(即乘积在位索引中具有非零值 > 52)——但我只想要低 52 位?在后一种情况下,MUL 会给我更高的位并舍弃一些较低的位(也许这就是 IFMA 的帮助?)。

编辑: 事实上,基于 this answer,它也许可以做任何高达 2^53 的事情——我忘记了隐含的前导 1 在尾数有效地给你另一位之前。


1 有趣的是,64 位产品 PMULDQ 操作的延迟是 32 位 PMULLD 版本的一半,吞吐量是 Mysticial 的两倍 在评论中。

嗯,你当然可以对整数进行 FP-lane 操作。它们将始终是精确的:虽然有一些 SSE 指令不能保证正确的 IEEE-754 精度和舍入,但它们毫无例外地是没有整数范围的指令,所以无论如何都不是您正在查看的指令。底线:Addition/subtraction/multiplication 在整数域中始终是精确的,即使您是在压缩浮点数上进行计算也是如此。

至于四精度浮点数(>52 位尾数),不,它们不受支持,并且在可预见的将来可能不会。只是对他们的要求不高。它们出现在一些 SPARC 时代的工作站架构中,但老实说,它们只是开发人员对如何编写数值稳定算法的不完全理解的绷带,并且随着时间的推移它们逐渐淡出。

事实证明,宽整数运算非常不适合 SSE。我最近在实现一个大整数库时真的尝试利用它,老实说它对我没有好处。 x86 设计 用于多字运算;您可以在 ADC(产生并消耗进位位)和 IDIV(允许除数的宽度是被除数的两倍,只要商不大于被除数的宽度)等操作中看到它,这是一个约束对多词除法毫无用处)。但是多字运算本质上是顺序的,而 SSE 本质上是并行的。如果您足够幸运,您的数字 刚好 位可以放入 FP 尾数,那么恭喜您。但是,如果您通常有大整数,SSE 可能不会成为您的朋友。

是的,这是可能的。但是对于 AVX2,它不太可能比 MULX/ADCX/ADOX.

的标量方法更好

对于不同的 input/output 域,这种方法实际上有无数种变体。我只会介绍其中的 3 个,但一旦您了解它们的工作原理,它们就很容易概括。

免责声明:

  • 这里的所有解决方案都假设舍入模式是四舍五入。
  • 不推荐使用快速数学优化标志,因为这些解决方案依赖于严格的 IEEE。

范围内的有符号双打:[-251, 251]

//  A*B = L + H*2^52
//  Input:  A and B are in the range [-2^51, 2^51]
//  Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256d& L, __m256d& H, __m256d A, __m256d B){
    const __m256d ROUND = _mm256_set1_pd(30423614405477505635920876929024.);    //  3 * 2^103
    const __m256d SCALE = _mm256_set1_pd(1. / 4503599627370496);                //  1 / 2^52

    //  Multiply and add normalization constant. This forces the multiply
    //  to be rounded to the correct number of bits.
    H = _mm256_fmadd_pd(A, B, ROUND);

    //  Undo the normalization.
    H = _mm256_sub_pd(H, ROUND);

    //  Recover the bottom half of the product.
    L = _mm256_fmsub_pd(A, B, H);

    //  Correct the scaling of H.
    H = _mm256_mul_pd(H, SCALE);
}

这是最简单的一种,也是唯一一种可以与标量方法竞争的方法。最终缩放是可选的,具体取决于您要对输出执行的操作。所以这可以被认为只有 3 条指令。但它也是最没用的,因为输入和输出都是浮点值。

两个 FMA 保持融合是绝对关键的。这就是快速数学优化可以打破局面的地方。如果第一个 FMA 被打破,那么 L 不再保证在 [-2^51, 2^51] 范围内。如果第二个FMA被向上突破,L就完全错了


范围内的有符号整数:[-251, 251]

//  A*B = L + H*2^52
//  Input:  A and B are in the range [-2^51, 2^51]
//  Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256i& L, __m256i& H, __m256i A, __m256i B){
    const __m256d CONVERT_U = _mm256_set1_pd(6755399441055744);     //  3*2^51
    const __m256d CONVERT_D = _mm256_set1_pd(1.5);

    __m256d l, h, a, b;

    //  Convert to double
    A = _mm256_add_epi64(A, _mm256_castpd_si256(CONVERT_U));
    B = _mm256_add_epi64(B, _mm256_castpd_si256(CONVERT_D));
    a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
    b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);

    //  Get top half. Convert H to int64.
    h = _mm256_fmadd_pd(a, b, CONVERT_U);
    H = _mm256_sub_epi64(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));

    //  Undo the normalization.
    h = _mm256_sub_pd(h, CONVERT_U);

    //  Recover bottom half.
    l = _mm256_fmsub_pd(a, b, h);

    //  Convert L to int64
    l = _mm256_add_pd(l, CONVERT_D);
    L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_D));
}

在第一个示例的基础上,我们将其与 .

的通用版本相结合

这个更有用,因为你正在处理整数。但即使使用快速转换技巧,大部分时间还是会花在转换上。幸运的是,如果多次乘以相同的操作数,您可以消除一些输入转换。


范围内的无符号整数:[0, 252)

//  A*B = L + H*2^52
//  Input:  A and B are in the range [0, 2^52)
//  Output: L and H are in the range [0, 2^52)
void mul52_unsigned(__m256i& L, __m256i& H, __m256i A, __m256i B){
    const __m256d CONVERT_U = _mm256_set1_pd(4503599627370496);     //  2^52
    const __m256d CONVERT_D = _mm256_set1_pd(1);
    const __m256d CONVERT_S = _mm256_set1_pd(1.5);

    __m256d l, h, a, b;

    //  Convert to double
    A = _mm256_or_si256(A, _mm256_castpd_si256(CONVERT_U));
    B = _mm256_or_si256(B, _mm256_castpd_si256(CONVERT_D));
    a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
    b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);

    //  Get top half. Convert H to int64.
    h = _mm256_fmadd_pd(a, b, CONVERT_U);
    H = _mm256_xor_si256(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));

    //  Undo the normalization.
    h = _mm256_sub_pd(h, CONVERT_U);

    //  Recover bottom half.
    l = _mm256_fmsub_pd(a, b, h);

    //  Convert L to int64
    l = _mm256_add_pd(l, CONVERT_S);
    L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_S));

    //  Make Correction
    H = _mm256_sub_epi64(H, _mm256_srli_epi64(L, 63));
    L = _mm256_and_si256(L, _mm256_set1_epi64x(0x000fffffffffffff));
}

终于得到了原问题的答案。这通过调整转换并添加更正步骤来构建有符号整数解决方案。

但此时,我们有 13 条指令 - 其中一半是高延迟指令,还不包括无数 FP <-> int 旁路延迟。因此,这不太可能赢得任何基准测试。相比之下,64 x 64 -> 128-bit SIMD 乘法可以用 16 条指令完成(如果对输入进行预处理,则为 14 条。)

如果舍入方式为向下舍入或舍入为零,则可以省略校正步骤。唯一重要的指令是 h = _mm256_fmadd_pd(a, b, CONVERT_U);。因此在 AVX512 上,您可以覆盖该指令的舍入并单独保留舍入模式。


最后的想法:

值得注意的是,252的操作范围可以通过调整魔法常数来减小。这可能对第一个解决方案(浮点解决方案)有用,因为它为您提供了额外的尾数以用于浮点累加。这使您无需像在最后两个解决方案中那样不断地在 int64 和 double 之间来回转换。

虽然这里的 3 个示例不太可能比标量方法更好,但 AVX512 几乎肯定会打破平衡。 Knights Landing 的 ADCX 和 ADOX 吞吐量尤其低。

当然,当 AVX512-IFMA 出来时,所有这些都没有实际意义。这将完整的 52 x 52 -> 104-bit 产品减少到 2 条指令,并免费提供累积。

进行多字整数运算的一种方法是使用 double-double arithmetic。让我们从一些双双乘法代码开始

#include <math.h>
typedef struct {
  double hi;
  double lo;
} doubledouble;

static doubledouble quick_two_sum(double a, double b) {
  double s = a + b;
  double e = b - (s - a);
  return (doubledouble){s, e};
}

static doubledouble two_prod(double a, double b) {
  double p = a*b;
  double e = fma(a, b, -p);
  return (doubledouble){p, e};
}

doubledouble df64_mul(doubledouble a, doubledouble b) {
  doubledouble p = two_prod(a.hi, b.hi);
  p.lo += a.hi*b.lo;
  p.lo += a.lo*b.hi;
  return quick_two_sum(p.hi, p.lo);
}

函数two_prod可以用两条指令完成整数53bx53b -> 106b。函数 df64_mul 可以做整数 106bx106b -> 106b.

让我们将其与具有整数硬件的整数 128bx128b -> 128b 进行比较。

__int128 mul128(__int128 a, __int128 b) {
  return a*b;
}

mul128

的程序集
imul    rsi, rdx
mov     rax, rdi
imul    rcx, rdi
mul     rdx
add     rcx, rsi
add     rdx, rcx

df64_mul 的程序集(用 gcc -O3 -S i128.c -masm=intel -mfma -ffp-contract=off 编译)

vmulsd      xmm4, xmm0, xmm2
vmulsd      xmm3, xmm0, xmm3
vmulsd      xmm1, xmm2, xmm1
vfmsub132sd xmm0, xmm4, xmm2
vaddsd      xmm3, xmm3, xmm0
vaddsd      xmm1, xmm3, xmm1
vaddsd      xmm0, xmm1, xmm4
vsubsd      xmm4, xmm0, xmm4
vsubsd      xmm1, xmm1, xmm4

mul128 执行三个标量乘法和两个标量 additions/subtractions,而 df64_mul 执行 3 个 SIMD 乘法、1 个 SIMD FMA 和 5 个 SIMD additions/subtractions。我没有描述这些方法,但对我来说,df64_mul 可以胜过 mul128,每个 AVX 寄存器使用 4-double(将 sd 更改为 pdxmmymm).


很容易说问题是切换回整数域。但为什么这是必要的?您可以在浮点域中做任何事情。让我们看一些例子。我发现使用 float 进行单元测试比使用 double.

更容易
doublefloat two_prod(float a, float b) {
  float p = a*b;
  float e = fma(a, b, -p);
  return (doublefloat){p, e};
}

//3202129*4807935=15395628093615
x = two_prod(3202129,4807935)
int64_t hi = p, lo = e, s = hi+lo
//p = 1.53956280e+13, e = 1.02575000e+05  
//hi = 15395627991040, lo = 102575, s = 15395628093615

//1450779*1501672=2178594202488
y = two_prod(1450779, 1501672)
int64_t hi = p, lo = e, s = hi+lo 
//p = 2.17859424e+12, e = -4.00720000e+04
//hi = 2178594242560 lo = -40072, s = 2178594202488

所以我们最终得到不同的范围,在第二种情况下,错误 (e) 甚至是负数,但总和仍然正确。我们甚至可以将两个 doublefloat 值 xy 加在一起(一旦我们知道如何进行双双加法 - 请参阅最后的代码)并得到 15395628093615+2178594202488。无需对结果进行归一化。

但是加法带来了双双运算的主要问题。即,addition/subtraction 很慢,例如128b+128b -> 128b needs at least 11 floating point additions 而对于整数它只需要两个(addadc)。

因此,如果一个算法重乘法而轻加法,那么用 double-double 进行多字整数运算可能会获胜。


附带说明一下,C 语言足够灵活,可以实现整数完全通过浮点硬件实现的实现。 int 可能是 24 位(来自单个浮点数),long 可能是 54 位。 (来自双浮点),long long 可以是 106 位(来自双精度)。 C 甚至不需要补码,因此整数可以像浮点数一样对负数使用带符号的大小。


这里是双倍乘法和加法的工作 C 代码(我没有实现除法或其他操作,例如 sqrt 但有论文展示了如何做到这一点)以防有人想玩它。看看这是否可以针对整数进行优化会很有趣。

//if compiling with -mfma you must also use -ffp-contract=off
//float-float is easier to debug. If you want double-double replace
//all float words with double and fmaf with fma 
#include <stdio.h>
#include <math.h>
#include <inttypes.h>
#include <x86intrin.h>
#include <stdlib.h>

//#include <float.h>

typedef struct {
  float hi;
  float lo;
} doublefloat;

typedef union {
  float f;
  int i;
  struct {
    unsigned mantisa : 23;
    unsigned exponent: 8;
    unsigned sign: 1;
  };
} float_cast;

void print_float(float_cast a) {
  printf("%.8e, 0x%x, mantisa 0x%x, exponent 0x%x, expondent-127 %d, sign %u\n", a.f, a.i, a.mantisa, a.exponent, a.exponent-127, a.sign);
}

void print_doublefloat(doublefloat a) {
  float_cast hi = {a.hi};
  float_cast lo = {a.lo};
  printf("hi: "); print_float(hi);
  printf("lo: "); print_float(lo);
}

doublefloat quick_two_sum(float a, float b) {
  float s = a + b;
  float e = b - (s - a);
  return (doublefloat){s, e};
  // 3 add
}

doublefloat two_sum(float a, float b) {
  float s = a + b;
  float v = s - a;
  float e = (a - (s - v)) + (b - v);
  return (doublefloat){s, e};
  // 6 add 
}

doublefloat df64_add(doublefloat a, doublefloat b) {
  doublefloat s, t;
  s = two_sum(a.hi, b.hi);
  t = two_sum(a.lo, b.lo);
  s.lo += t.hi;
  s = quick_two_sum(s.hi, s.lo);
  s.lo += t.lo;
  s = quick_two_sum(s.hi, s.lo);
  return s;
  // 2*two_sum, 2 add, 2*quick_two_sum = 2*6 + 2 + 2*3 = 20 add
}

doublefloat split(float a) {
  //#define SPLITTER (1<<27) + 1
#define SPLITTER (1<<12) + 1
  float t = (SPLITTER)*a;
  float hi = t - (t - a);
  float lo = a - hi;
  return (doublefloat){hi, lo};
  // 1 mul, 3 add
}

doublefloat split_sse(float a) {
  __m128 k = _mm_set1_ps(4097.0f);
  __m128 a4 = _mm_set1_ps(a);
  __m128 t = _mm_mul_ps(k,a4);
  __m128 hi4 = _mm_sub_ps(t,_mm_sub_ps(t, a4));
  __m128 lo4 = _mm_sub_ps(a4, hi4);
  float tmp[4];
  _mm_storeu_ps(tmp, hi4);
  float hi = tmp[0];
  _mm_storeu_ps(tmp, lo4);
  float lo = tmp[0];
  return (doublefloat){hi,lo};

}

float mult_sub(float a, float b, float c) {
  doublefloat as = split(a), bs = split(b);
  //print_doublefloat(as);
  //print_doublefloat(bs);
  return ((as.hi*bs.hi - c) + as.hi*bs.lo + as.lo*bs.hi) + as.lo*bs.lo;
  // 4 mul, 4 add, 2 split = 6 mul, 10 add
}

doublefloat two_prod(float a, float b) {
  float p = a*b;
  float e = mult_sub(a, b, p);
  return (doublefloat){p, e};
  // 1 mul, one mult_sub
  // 7 mul, 10 add
}

float mult_sub2(float a, float b, float c) {
  doublefloat as = split(a);
  return ((as.hi*as.hi -c ) + 2*as.hi*as.lo) + as.lo*as.lo;
}

doublefloat two_sqr(float a) {
  float p = a*a;
  float e = mult_sub2(a, a, p);
  return (doublefloat){p, e};
}

doublefloat df64_mul(doublefloat a, doublefloat b) {
  doublefloat p = two_prod(a.hi, b.hi);
  p.lo += a.hi*b.lo;
  p.lo += a.lo*b.hi;
  return quick_two_sum(p.hi, p.lo);
  //two_prod, 2 add, 2mul, 1 quick_two_sum = 9 mul, 15 add 
  //or 1 mul, 1 fma, 2add 2mul, 1 quick_two_sum = 3 mul, 1 fma, 5 add
}

doublefloat df64_sqr(doublefloat a) {
  doublefloat p = two_sqr(a.hi);
  p.lo += 2*a.hi*a.lo;
  return quick_two_sum(p.hi, p.lo);
}

int float2int(float a) {
  int M = 0xc00000; //1100 0000 0000 0000 0000 0000
  a += M;
  float_cast x;
  x.f = a;
  return x.i - 0x4b400000;
}

doublefloat add22(doublefloat a, doublefloat b) {
  float r = a.hi + b.hi;
  float s = fabsf(a.hi) > fabsf(b.hi) ?
    (((a.hi - r) + b.hi) + b.lo ) + a.lo :
    (((b.hi - r) + a.hi) + a.lo ) + b.lo;
  return two_sum(r, s);  
  //11 add 
}

int main(void) {
  //print_float((float_cast){1.0f});
  //print_float((float_cast){-2.0f});
  //print_float((float_cast){0.0f});
  //print_float((float_cast){3.14159f});
  //print_float((float_cast){1.5f});
  //print_float((float_cast){3.0f});
  //print_float((float_cast){7.0f});
  //print_float((float_cast){15.0f});
  //print_float((float_cast){31.0f});

  //uint64_t t = 0xffffff;
  //print_float((float_cast){1.0f*t});
  //printf("%" PRId64 " %" PRIx64 "\n", t*t,t*t);

  /*
    float_cast t1;
    t1.mantisa = 0x7fffff;
    t1.exponent = 0xfe;
    t1.sign = 0;
    print_float(t1);
  */
  //doublefloat z = two_prod(1.0f*t, 1.0f*t);
  //print_doublefloat(z);
  //double z2 = (double)z.hi + (double)z.lo;
  //printf("%.16e\n", z2);
  doublefloat s = {0};
  int64_t si = 0;
  for(int i=0; i<100000; i++) {
    int ai = rand()%0x800, bi = rand()%0x800000;
    float a = ai, b = bi;
    doublefloat z = two_prod(a,b);
    int64_t zi = (int64_t)ai*bi;
    //print_doublefloat(z);
    //s = df64_add(s,z);
    s = add22(s,z);
    si += zi;
    print_doublefloat(z);
    printf("%d %d ", ai,bi);
    int64_t h = z.hi;
    int64_t l = z.lo;
    int64_t t = h+l;
    //if(t != zi) printf("%" PRId64 " %" PRId64 "\n", h, l);

    printf("%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "\n", zi, h, l, h+l);

    h = s.hi;
    l = s.lo;
    t = h + l;
    //if(si != t) printf("%" PRId64 " %" PRId64 "\n", h, l);

    if(si > (1LL<<48)) {
      printf("overflow after %d iterations\n", i); break;
    }
  }

  print_doublefloat(s);
  printf("%" PRId64 "\n", si);
  int64_t x = s.hi;
  int64_t y = s.lo;
  int64_t z = x+y;
  //int hi = float2int(s.hi);
  printf("%" PRId64 " %" PRId64 " %" PRId64 "\n", z,x,y);
}