高效计算三个无符号整数的平均值(无溢出)

Efficient computation of the average of three unsigned integers (without overflow)

有一个existing question“3个长整数的平均值”,专门用于有效计算三个有符号整数的平均值。

然而,使用无符号整数可以进行额外的优化,但不适用于上一个问题中涵盖的场景。这个问题是关于三个 unsigned 整数的平均值的有效计算,其中平均值向零四舍五入,即在数学术语中我想计算 ⌊ (a + b + c) / 3 ⌋.

计算此平均值的一种直接方法是

 avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;

首先,现代优化编译器会将除法转换为倒数加移位的乘法,并将模运算转换为反乘和减法,其中反乘可能使用 scale_add 习语可用于许多体系结构,例如lea 在 x86_64 上,add 在 ARM 上 lsl #n,在 NVIDIA GPU 上 iscadd

在尝试以适用于许多常见平台的通用方式优化上述内容时,我观察到整数运算的成本通常处于 逻辑 ≤ (add | sub) ≤ shiftscale_addmul。这里的成本是指所有的延迟、吞吐量限制和功耗。当处理的整数类型比本地寄存器宽度更宽时,任何此类差异都会变得更加明显,例如在 32 位处理器上处理 uint64_t 数据时。

因此,我的优化策略是尽量减少指令数量并尽可能用“廉价”操作替换“昂贵”操作,同时不增加寄存器压力并为广泛的无序处理器保留可利用的并行性。

第一个观察结果是,我们可以通过首先应用生成和值和进位值的 CSA(进位保存加法器)将三个操作数的和减少为两个操作数的和,其中进位值有两倍总和值的权重。在大多数处理器上,基于软件的 CSA 的成本是五个逻辑。某些处理器,如 NVIDIA GPU,有一条 LOP3 指令,可以一次性计算出三个操作数的任意逻辑表达式,在这种情况下,CSA 浓缩为两个 LOP3(注意:我还没有说服CUDA 编译器发出这两个 LOP3;它目前产生四个 LOP3!)。

第二个观察结果是,因为我们正在计算除以 3 的模数,所以不需要反向乘法来计算它。我们可以改为使用 dividend % 3 = ((dividend / 3) + dividend) & 3,将模减少为 add 加上 logical 因为我们已经有了划分结果。这是一般算法的一个实例: dividend % (2n-1) = ((dividend / (2n-1) + dividend ) & (2n-1).

最后在校正项(a % 3 + b % 3 + c % 3) / 3中除以3,我们不需要泛型除以3的代码。由于被除数很小,在[0, 6],我们可以简化x / 3(3 * x) / 8 只需要一个 scale_add 加上一个 shift.

下面的代码显示了我当前正在进行的工作。使用 Compiler Explorer 检查为各种平台生成的代码显示了我所期望的紧凑代码(当使用 -O3 编译时)。

然而,在我的 Ivy Bridge x86_64 机器上使用英特尔 13.x 编译器对代码进行计时时,一个缺陷变得明显:虽然我的代码改善了延迟(从 18 个周期到 15 个周期 uint64_t 数据)与简单版本相比,吞吐量变差(对于 uint64_t 数据,从每 6.8 个周期一个结果到每 8.5 个周期一个结果)。更仔细地查看汇编代码,原因就很明显了:我基本上设法将代码从大致的三向并行性降低到大致的双向并行性。

是否有一种普遍适用的优化技术,对普通处理器有益,特别是所有类型的 x86 和 ARM 以及 GPU,它可以保持更多的并行性?或者,是否有一种优化技术可以进一步减少总体操作数以弥补并行度的降低?校正项的计算(下面代码中的 tail)似乎是一个很好的目标。简化 (carry_mod_3 + sum_mod_3) / 2 看起来很诱人,但对于九种可能组合中的一种,结果不正确。

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>

#define BENCHMARK           (1)
#define SIMPLE_COMPUTATION  (0)

#if BENCHMARK
#define T uint64_t
#else // !BENCHMARK
#define T uint8_t
#endif // BENCHMARK

T average_of_3 (T a, T b, T c) 
{
    T avg;

#if SIMPLE_COMPUTATION
    avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;
#else // !SIMPLE_COMPUTATION
    /* carry save adder */
    T a_xor_b = a ^ b;
    T sum = a_xor_b ^ c;
    T carry = (a_xor_b & c) | (a & b);
    /* here 2 * carry + sum = a + b + c */
    T sum_div_3 = (sum / 3);                                   // {MUL|MULHI}, SHR
    T sum_mod_3 = (sum + sum_div_3) & 3;                       // ADD, AND

    if (sizeof (size_t) == sizeof (T)) { // "native precision" (well, not always)
        T two_carry_div_3 = (carry / 3) * 2;                   // MULHI, ANDN
        T two_carry_mod_3 = (2 * carry + two_carry_div_3) & 6; // SCALE_ADD, AND
        T head = two_carry_div_3 + sum_div_3;                  // ADD
        T tail = (3 * (two_carry_mod_3 + sum_mod_3)) / 8;      // ADD, SCALE_ADD, SHR
        avg = head + tail;                                     // ADD
    } else {
        T carry_div_3 = (carry / 3);                           // MUL, SHR
        T carry_mod_3 = (carry + carry_div_3) & 3;             // ADD, AND
        T head = (2 * carry_div_3 + sum_div_3);                // SCALE_ADD
        T tail = (3 * (2 * carry_mod_3 + sum_mod_3)) / 8;      // SCALE_ADD, SCALE_ADD, SHR
        avg = head + tail;                                     // ADD
    }
#endif // SIMPLE_COMPUTATION
    return avg;
}

#if !BENCHMARK
/* Test correctness on 8-bit data exhaustively. Should catch most errors */
int main (void)
{
    T a, b, c, res, ref;
    a = 0;
    do {
        b = 0;
        do {
            c = 0;
            do {
                res = average_of_3 (a, b, c);
                ref = ((uint64_t)a + (uint64_t)b + (uint64_t)c) / 3;
                if (res != ref) {
                    printf ("a=%08x  b=%08x  c=%08x  res=%08x  ref=%08x\n", 
                            a, b, c, res, ref);
                    return EXIT_FAILURE;
                }
                c++;
            } while (c);
            b++;
        } while (b);
        a++;
    } while (a);
    return EXIT_SUCCESS;
}

#else // BENCHMARK

#include <math.h>

// A routine to give access to a high precision timer on most systems.
#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
    LARGE_INTEGER t;
    static double oofreq;
    static int checkedForHighResTimer;
    static BOOL hasHighResTimer;

    if (!checkedForHighResTimer) {
        hasHighResTimer = QueryPerformanceFrequency (&t);
        oofreq = 1.0 / (double)t.QuadPart;
        checkedForHighResTimer = 1;
    }
    if (hasHighResTimer) {
        QueryPerformanceCounter (&t);
        return (double)t.QuadPart * oofreq;
    } else {
        return (double)GetTickCount() * 1.0e-3;
    }
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
}
#else
#error unsupported platform
#endif

#define N  (3000000)
int main (void)
{
    double start, stop, elapsed = INFINITY;
    int i, k;
    T a, b;
    T avg0  = 0xffffffff,  avg1 = 0xfffffffe;
    T avg2  = 0xfffffffd,  avg3 = 0xfffffffc;
    T avg4  = 0xfffffffb,  avg5 = 0xfffffffa;
    T avg6  = 0xfffffff9,  avg7 = 0xfffffff8;
    T avg8  = 0xfffffff7,  avg9 = 0xfffffff6;
    T avg10 = 0xfffffff5, avg11 = 0xfffffff4;
    T avg12 = 0xfffffff2, avg13 = 0xfffffff2;
    T avg14 = 0xfffffff1, avg15 = 0xfffffff0;

    a = 0x31415926;
    b = 0x27182818;
    avg0 = average_of_3 (a, b, avg0);
    for (k = 0; k < 5; k++) {
        start = second();
        for (i = 0; i < N; i++) {
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            b = (b + avg0) ^ a;
            a = (a ^ b) + avg0;
        }
        stop = second();
        elapsed = fmin (stop - start, elapsed);
    }
    printf ("a=%016llx b=%016llx avg=%016llx", 
            (uint64_t)a, (uint64_t)b, (uint64_t)avg0);
    printf ("\rlatency:    each average_of_3() took  %.6e seconds\n", 
            elapsed / 16 / N);


    a = 0x31415926;
    b = 0x27182818;
    avg0 = average_of_3 (a, b, avg0);
    for (k = 0; k < 5; k++) {
        start = second();
        for (i = 0; i < N; i++) {
            avg0  = average_of_3 (a, b, avg0);
            avg1  = average_of_3 (a, b, avg1);
            avg2  = average_of_3 (a, b, avg2);
            avg3  = average_of_3 (a, b, avg3);
            avg4  = average_of_3 (a, b, avg4);
            avg5  = average_of_3 (a, b, avg5);
            avg6  = average_of_3 (a, b, avg6);
            avg7  = average_of_3 (a, b, avg7);
            avg8  = average_of_3 (a, b, avg8);
            avg9  = average_of_3 (a, b, avg9);
            avg10 = average_of_3 (a, b, avg10);
            avg11 = average_of_3 (a, b, avg11);
            avg12 = average_of_3 (a, b, avg12);
            avg13 = average_of_3 (a, b, avg13);
            avg14 = average_of_3 (a, b, avg14);
            avg15 = average_of_3 (a, b, avg15);
            b = (b + avg0) ^ a;
            a = (a ^ b) + avg0;
        }
        stop = second();
        elapsed = fmin (stop - start, elapsed);
    }
    printf ("a=%016llx b=%016llx avg=%016llx", (uint64_t)a, (uint64_t)b, 
            (uint64_t)(avg0 + avg1 + avg2 + avg3 + avg4 + avg5 + avg6 + avg7 + 
                       avg8 + avg9 +avg10 +avg11 +avg12 +avg13 +avg14 +avg15));
    printf ("\rthroughput: each average_of_3() took  %.6e seconds\n", 
            elapsed / 16 / N);

    return EXIT_SUCCESS;
}

#endif // BENCHMARK

我怀疑 SIMPLE 正在通过 CSEing 和提升 a/3+b/3a%3+b%3 循环外,将这些结果重新用于所有 16 个 avg0..15 结果来击败吞吐量基准.

(SIMPLE 版本比 tricky 版本可以完成更多的工作;在该版本中实际上只有 a ^ ba & b。)

强制函数不内联会引入更多 front-end 开销,但确实会使您的版本获胜,正如我们预期的那样,它应该在具有深度 out-of-order 执行缓冲区重叠的 CPU 上独立工作。对于吞吐量基准,有很多 ILP 可以跨迭代找到。 (我没有仔细查看 non-inline 版本的 asm。)

https://godbolt.org/z/j95qn3(在 Godbolt 的 SKX CPUs 上使用 __attribute__((noinline))clang -O3 -march=skylake)显示简单方法的吞吐量为 2.58 纳秒,您的方法为 2.48 纳秒。与 1.17 纳秒吞吐量相比,简单版本具有内联。

-march=skylake 允许 mulx 更灵活 full-multiply,但除此之外 BMI2 没有任何好处。 andn 未使用;您用 mulhi / andn 评论的行是 mulx 到 RCX / and rcx, -2 中,只需要 sign-extended 立即数。


另一种不强制 call/ret 开销的方法是像 Preventing compiler optimizations while benchmarking 中的内联汇编(Chandler Carruth 的 CppCon 演讲有一些他如何使用包装器的例子),或者 Google 基准 benchmark::DoNotOptimize.

具体来说,每个 avgX = average_of_3 (a, b, avgX); 语句之间的 GNU C asm("" : "+r"(a), "+r"(b)) 使编译器忘记它所知道的关于 ab,同时将它们保存在寄存器中。

我对 的回答更详细地介绍了使用 read-only "r" 寄存器约束来强制编译器在寄存器中实现结果,而不是 "+r" 使其假定值已被修改。

如果您很好地理解 GNU C 内联汇编,以您确切知道它们的作用的方式推出您自己的汇编可能会更容易。

我不确定它是否符合您的要求,但也许它只计算结果然后修复溢出的错误:

T average_of_3 (T a, T b, T c)
{
    T r = ((T) (a + b + c)) / 3;
    T o = (a > (T) ~b) + ((T) (a + b) > (T) (~c));
    if (o) r += ((T) 0x5555555555555555) << (o - 1);
    T rem = ((T) (a + b + c)) % 3;
    if (rem >= (3 - o)) ++r;
    return r;
}

[编辑] 这是我能想到的最好的 branch-and-compare-less 版本。在我的机器上,这个版本的吞吐量实际上比 njuffa 的代码略高。 __builtin_add_overflow(x, y, r) 被 gcc 和 clang 支持,returns 1 如果总和 x + y 溢出 *r 的类型,否则 0,所以计算o 相当于第一个版本中的可移植代码,但至少 gcc 使用内置函数生成更好的代码。

T average_of_3 (T a, T b, T c)
{
    T r = ((T) (a + b + c)) / 3;
    T rem = ((T) (a + b + c)) % 3;
    T dummy;
    T o = __builtin_add_overflow(a, b, &dummy) + __builtin_add_overflow((T) (a + b), c, &dummy);
    r += -((o - 1) & 0xaaaaaaaaaaaaaaab) ^ 0x5555555555555555;
    r += (rem + o + 1) >> 2;
    return r;
}

[Falk Hüffner 在评论中指出这个答案与 有相似之处。后来更仔细地查看他的代码,我确实发现了一些相似之处。然而,我在这里发布的是一个独立思考过程的产物,是我最初想法“在 div-mod 之前将三项减少为两项”的延续。我理解 Hüffner 的方法是不同的:“天真计算后进行更正”。]

我发现了一种比问题中的 CSA-technique 更好的方法,可以将除法和取模工作从三个操作数减少到两个操作数。首先,形成完整的 double-word 总和,然后分别对每一半应用除法和模 3,最后合并结果。由于最重要的一半只能取值 0、1 或 2,因此计算除以三的商和余数是微不足道的。而且,最终结果的组合也变得更简单。

与问题中的 non-simple 代码变体相比,这在我检查的所有平台上都实现了加速。编译器为模拟 double-word 添加生成的代码质量各不相同,但总体上令人满意。尽管如此,以 non-portable 方式对这部分进行编码可能是值得的,例如使用内联汇编。

T average_of_3_hilo (T a, T b, T c) 
{
    const T fives = (((T)(~(T)0)) / 3); // 0x5555...
    T avg, hi, lo, lo_div_3, lo_mod_3, hi_div_3, hi_mod_3; 
    /* compute the full sum a + b + c into the operand pair hi:lo */
    lo = a + b;
    hi = lo < a;
    lo = c + lo;
    hi = hi + (lo < c);
    /* determine quotient and remainder of each half separately */
    lo_div_3 = lo / 3;
    lo_mod_3 = (lo + lo_div_3) & 3;
    hi_div_3 = hi * fives;
    hi_mod_3 = hi;
    /* combine partial results into the division result for the full sum */
    avg = lo_div_3 + hi_div_3 + ((lo_mod_3 + hi_mod_3 + 1) / 4);
    return avg;
}

让我把我的帽子扔进戒指。在这里不做任何太棘手的事情,我 想想。

#include <stdint.h>

uint64_t average_of_three(uint64_t a, uint64_t b, uint64_t c) {
  uint64_t hi = (a >> 32) + (b >> 32) + (c >> 32);
  uint64_t lo = hi + (a & 0xffffffff) + (b & 0xffffffff) + (c & 0xffffffff);
  return 0x55555555 * hi + lo / 3;
}

在下面讨论了不同的分割之后,这里有一个以三个为代价节省乘法的版本 bitwise-ANDs:

T hi = (a >> 2) + (b >> 2) + (c >> 2);
T lo = (a & 3) + (b & 3) + (c & 3);
avg = hi + (hi + lo) / 3;

新答案,新思路。这个基于数学恒等式

floor((a+b+c)/3) = floor(x + (a+b+c - 3x)/3)

这何时适用于机器整数和无符号除法?
当差异不回绕时,即 0 ≤ a+b+c - 3x ≤ T_MAX.

这个 x 的定义很快,可以完成工作。

T avg3(T a, T b, T c) {
  T x = (a >> 2) + (b >> 2) + (c >> 2);
  return x + (a + b + c - 3 * x) / 3;
}

奇怪的是,除非我这样做,否则 ICC 会插入一个额外的否定:

T avg3(T a, T b, T c) {
  T x = (a >> 2) + (b >> 2) + (c >> 2);
  return x + (a + b + c - (x + x * 2)) / 3;
}

请注意 T 必须至少有五位宽。

如果T是两个平台字长,那么你可以通过省略x的低位字来节省一些双字操作。

延迟更差但吞吐量可能略高的替代版本?

T lo = a + b;
T hi = lo < b;
lo += c;
hi += lo < c;
T x = (hi << (sizeof(T) * CHAR_BIT - 2)) + (lo >> 2);
avg = x + (T)(lo - 3 * x) / 3;

我已经回答了您链接到的问题,所以我只回答与此不同的部分:性能。

如果您真的很关心性能,那么答案是:

( a + b + c ) / 3

既然您关心性能,那么您应该对所处理的数据大小有一个直觉。你不应该担心只有 3 个值的加法(乘法是另一回事)溢出,因为如果你的数据已经足够大可以使用你选择的数据类型的高位,你无论如何都有溢出的危险并且应该使用更大的整数类型。如果你在 uint64_t 上溢出,那么你真的应该问问自己,为什么你需要准确地计数到 18 quintillion,或许可以考虑使用 float 或 double。

现在,说了这么多,我将给你我的实际答复:没关系。这个问题在现实生活中不会出现,什么时候会出现, 性能无所谓。

如果您在 SIMD 中执行一百万次,这可能是一个真正的性能问题,因为在那里,您确实被激励使用更小宽度的整数,并且您可能需要最后一点余量,但那不是'不是你的问题。

GCC-11 的实验性构建将明显的原始函数编译成如下内容:

uint32_t avg3t (uint32_t a, uint32_t b, uint32_t c) {
    a += b;
    b = a < b;
    a += c;
    b += a < c;

    b = b + a;
    b += b < a;
    return (a - (b % 3)) * 0xaaaaaaab;
}

这与此处发布的其他一些答案相似。 欢迎对这些解决方案如何工作的任何解释 (不确定这里的网络礼仪)。