Is there a branchless method to quickly find the min/max of two double-precision floating-point values?

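I have two doubles, a and b, which are both in [0,1]. For performance reasons, I want the min/max of a and b without branching.

Given that a and b are both positive and below 1, is there an efficient way of getting the min/max of the two? Ideally, I want no branching.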

Yes, there is a way to calculate the maximum or minimum of two doubles without any branches. The C++ code to do so looks like this:

#include <algorithm>

double FindMinimum(double a, double b)
{
    return std::min(a, b);
}

double FindMaximum(double a, double b)
{
    return std::max(a, b);
}

I bet you've seen this before. In case you don't believe that this is branchless, check out the disassembly:

FindMinimum(double, double):
    minsd   xmm1, xmm0
    movapd  xmm0, xmm1
    ret

FindMaximum(double, double):
    maxsd   xmm1, xmm0
    movapd  xmm0, xmm1
    ret

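That's what you get from all popular compilers targeting x86. The SSE2 instruction set is used, specifically the minsd/maxsd instructions, which branchlessly evaluate the minimum/maximum of two double-precision floating-point values.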
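For illustration, the same instructions can also be requested explicitly through the SSE2 intrinsics declared in <emmintrin.h>. This is only a minimal sketch: the names MinSd and MaxSd are hypothetical, and the std::min/std::max fallback for targets without SSE2 is an assumption added so the snippet compiles everywhere.

```cpp
#include <algorithm>

// Detect SSE2 on GCC/Clang (__SSE2__) and MSVC (_M_X64, or _M_IX86_FP >= 2).
#if defined(__SSE2__) || defined(_M_X64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
#include <emmintrin.h>
#define SKETCH_HAS_SSE2 1
#endif

// Hypothetical helper: minimum of two doubles via the scalar SSE2 min.
double MinSd(double a, double b)
{
#ifdef SKETCH_HAS_SSE2
    // _mm_set_sd places the value in the low lane; _mm_min_sd compares the
    // low lanes; _mm_cvtsd_f64 extracts the low lane as a double.
    return _mm_cvtsd_f64(_mm_min_sd(_mm_set_sd(a), _mm_set_sd(b)));
#else
    return std::min(a, b);  // assumed fallback for targets without SSE2
#endif
}

// Hypothetical helper: maximum of two doubles via the scalar SSE2 max.
double MaxSd(double a, double b)
{
#ifdef SKETCH_HAS_SSE2
    return _mm_cvtsd_f64(_mm_max_sd(_mm_set_sd(a), _mm_set_sd(b)));
#else
    return std::max(a, b);  // assumed fallback for targets without SSE2
#endif
}
```

In practice there is little reason to write this by hand, since the compilers shown above already emit these instructions for plain std::min and std::max.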

All 64-bit x86 processors support SSE2; it is required by the AMD64 extensions. Even most x86 processors without 64-bit support have SSE2, which was released in 2000. You'd have to go back a long way to find a processor that didn't support it. But what if you did? Well, even there, you get branchless code on most popular compilers:

FindMinimum(double, double):
    fld      QWORD PTR [esp + 12]
    fld      QWORD PTR [esp + 4]
    fucomi   st(1)
    fcmovnbe st(0), st(1)
    fstp     st(1)
    ret

FindMaximum(double, double):
    fld      QWORD PTR [esp + 4]
    fld      QWORD PTR [esp + 12]
    fucomi   st(1)
    fxch     st(1)
    fcmovnbe st(0), st(1)
    fstp     st(1)
    ret

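The fucomi instruction performs a comparison, setting flags, and then the fcmovnbe instruction performs a conditional move based on the value of those flags. This is all completely branchless, and relies on instructions introduced to the x86 ISA with the Pentium Pro back in 1995, supported on all x86 chips since the Pentium II.

The only compiler that doesn't generate branchless code here is MSVC, because it doesn't take advantage of the FCMOVxx instruction. Instead, you get: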

double FindMinimum(double, double) PROC
    fld     QWORD PTR [a]
    fld     QWORD PTR [b]
    fcom    st(1)            ; compare "b" to "a"
    fnstsw  ax               ; transfer FPU status word to AX register
    test    ah, 5            ; check C0 and C2 flags
    jp      Alt
    fstp    st(1)            ; return "b"
    ret
Alt:
    fstp    st(0)            ; return "a"
    ret
double FindMinimum(double, double) ENDP

double FindMaximum(double, double) PROC
    fld     QWORD PTR [b]
    fld     QWORD PTR [a]
    fcom    st(1)            ; compare "b" to "a"
    fnstsw  ax               ; transfer FPU status word to AX register
    test    ah, 5            ; check C0 and C2 flags
    jp      Alt
    fstp    st(0)            ; return "b"
    ret
Alt:
    fstp    st(1)            ; return "a"
    ret
double FindMaximum(double, double) ENDP

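Notice the branching JP instruction (jump if parity bit set). The FCOM instruction, which is part of the base x87 FPU instruction set, is used to do the comparison. Unfortunately, it sets flags in the FPU status word, so in order to branch on those flags, they need to be extracted. That's the purpose of the FNSTSW instruction, which stores the x87 FPU status word to the general-purpose AX register (it could also store to memory, but... why?). The code then TESTs the appropriate bits and branches accordingly to ensure that the correct value is returned. In addition to the branch, retrieving the FPU status word is also relatively slow. This is why the Pentium Pro introduced the FCOMI instructions, which set the CPU flags directly.

However, it is unlikely that you would be able to improve upon the speed of any of this code by using bit-twiddling operations to determine min/max. There are two basic reasons:

  1. The only compiler generating inefficient code is MSVC, and there's no good way to force it to generate the instructions you want. Although MSVC supports inline assembly for 32-bit x86 targets, it is a fool's errand when seeking performance improvements. I'll also quote myself: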

    Inline assembly disrupts the optimizer in rather significant ways, so unless you're writing significant swaths of code in inline assembly, there is unlikely to be a substantial net performance gain. Furthermore, Microsoft's inline assembly syntax is extremely limited. It trades flexibility for simplicity in a big way. In particular, there is no way to specify input values, so you're stuck loading the input from memory into a register, and the caller is forced to spill the input from a register to memory in preparation. This creates a phenomenon I like to call "a whole lotta shufflin' goin' on", or for short, "slow code". You don't drop to inline assembly in cases where slow code is acceptable. Thus, it is always preferable (at least on MSVC) to figure out how to write C/C++ source code that persuades the compiler to emit the object code you want. Even if you can only get close to the ideal output, that's still considerably better than the penalty you pay for using inline assembly.

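  2. In order to get access to the raw bits of a floating-point value, you'd have to do a domain transition, from floating-point to integer, and then back to floating-point. That's slow, especially without SSE2, because the only way to get a value from the x87 FPU into the general-purpose integer registers in the ALU is indirectly via memory.

If you wanted to pursue this strategy anyway (say, to benchmark it), you could take advantage of the fact that floating-point values are lexicographically ordered in terms of their IEEE 754 representations, except for the sign bit. So, since you are assuming that both values are positive: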

FindMinimumOfTwoPositiveDoubles(double a, double b):
    mov   rax, QWORD PTR [a]
    mov   rdx, QWORD PTR [b]
    sub   rax, rdx              ; subtract bitwise representation of the two values
    shr   rax, 63               ; isolate the sign bit to see if the result was negative
    ret

FindMaximumOfTwoPositiveDoubles(double a, double b):
    mov   rax, QWORD PTR [b]    ; \ reverse order of parameters
    mov   rdx, QWORD PTR [a]    ; /  for the SUB operation
    sub   rax, rdx
    shr   rax, 63
    ret

Or, to avoid inline assembly:

#include <climits>
#include <cstdint>
#include <cstring>

// Returns true when a < b, i.e. when a is the minimum.
// Only valid when both inputs are non-negative and neither is NaN.
bool FindMinimumOfTwoPositiveDoubles(double a, double b)
{
    static_assert(sizeof(a) == sizeof(uint64_t),
                  "A double must be the same size as a uint64_t for this bit manipulation to work.");
    uint64_t aBits, bBits;
    std::memcpy(&aBits, &a, sizeof(aBits));  // copy the raw bits without violating strict aliasing
    std::memcpy(&bBits, &b, sizeof(bBits));
    return ((aBits - bBits) >> ((sizeof(uint64_t) * CHAR_BIT) - 1));
}

// Returns true when a > b, i.e. when a is the maximum (same restrictions as above).
bool FindMaximumOfTwoPositiveDoubles(double a, double b)
{
    static_assert(sizeof(a) == sizeof(uint64_t),
                  "A double must be the same size as a uint64_t for this bit manipulation to work.");
    uint64_t aBits, bBits;
    std::memcpy(&aBits, &a, sizeof(aBits));
    std::memcpy(&bBits, &b, sizeof(bBits));
    return ((bBits - aBits) >> ((sizeof(uint64_t) * CHAR_BIT) - 1));
}
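Note that the functions above yield only a 0/1 flag saying which operand wins, not the value itself. As a rough sketch of how that flag could be turned into the actual minimum without a branch, the flag can be widened into an all-ones or all-zeros mask and used to select between the two bit patterns. The name MinOfTwoNonNegativeDoubles is hypothetical, and the code assumes IEEE 754 binary64 doubles that are non-negative and not NaN.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: returns the smaller of two non-negative, non-NaN doubles
// using only integer operations on their bit patterns.
double MinOfTwoNonNegativeDoubles(double a, double b)
{
    static_assert(sizeof(double) == sizeof(std::uint64_t),
                  "A double must be the same size as a uint64_t for this bit manipulation to work.");

    std::uint64_t aBits, bBits;
    std::memcpy(&aBits, &a, sizeof(aBits));
    std::memcpy(&bBits, &b, sizeof(bBits));

    // 1 if a < b, otherwise 0 (both patterns are below 2^63, so the
    // top bit of the wrapped difference acts as a borrow flag).
    const std::uint64_t aIsSmaller = (aBits - bBits) >> 63;

    // All ones if a < b, all zeros otherwise.
    const std::uint64_t mask = std::uint64_t{0} - aIsSmaller;

    // Select a's bits when the mask is set, b's bits otherwise.
    const std::uint64_t resultBits = (aBits & mask) | (bBits & ~mask);

    double result;
    std::memcpy(&result, &resultBits, sizeof(result));
    return result;
}
```

The maximum follows by swapping the operands of the subtraction. This still pays for the round trip between the floating-point and integer domains described earlier, so it is a candidate for benchmarking rather than a recommendation.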

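Note that there are severe caveats to this implementation. In particular, it will break if the two floating-point values have different signs, or if both values are negative. If both values are negative, then the code can be modified to flip their signs, do the comparison, and then return the opposite value. To handle the case where the two values have different signs, code can be added to check the sign bit.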

    // ...

    // Enforce two's-complement lexicographic ordering.
    // This assumes aBits and bBits are held as non-const int64_t values here:
    // an unsigned value is never less than zero, and (1 << 63) overflows an int.
    if (aBits < 0)
    {
        aBits = (INT64_MIN - aBits);
    }
    if (bBits < 0)
    {
        bBits = (INT64_MIN - bBits);
    }

    // ...

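Dealing with negative zero will also be a problem. IEEE 754 says that +0.0 is equal to −0.0, so your comparison function will have to decide whether it wants to treat these values as different, or add special code to the comparison routines to ensure that negative and positive zero are treated as equivalent.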
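To make the sign and negative-zero handling concrete, here is one possible sketch that maps each double to a signed integer key whose ordering matches the ordering of the doubles, with −0.0 and +0.0 collapsing to the same key. The name OrderedKey is hypothetical, and the code assumes IEEE 754 binary64 and inputs that are not NaN. Whether the conditional expression compiles to a branch or a conditional move is left to the compiler.

```cpp
#include <cstdint>
#include <cstring>
#include <limits>

// Hypothetical helper: converts a non-NaN double into an int64_t key such that
// OrderedKey(x) < OrderedKey(y) exactly when x < y.
std::int64_t OrderedKey(double value)
{
    static_assert(sizeof(double) == sizeof(std::int64_t),
                  "A double must be the same size as an int64_t for this bit manipulation to work.");

    std::int64_t bits;
    std::memcpy(&bits, &value, sizeof(bits));

    // Doubles are sign-magnitude: among negative values, a larger magnitude
    // has a larger bit pattern. Subtracting from INT64_MIN reverses that
    // order and maps -0.0 (bit pattern INT64_MIN) to 0, the key of +0.0.
    return (bits < 0) ? (std::numeric_limits<std::int64_t>::min() - bits) : bits;
}
```

With such keys, the comparison itself should use an ordinary signed `<` rather than the subtract-and-shift trick, because the difference of two keys with opposite signs can overflow.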

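Adding all of this special-case code will certainly reduce performance to the point that we, at best, break even with a naïve floating-point comparison, and it will very likely end up being slower.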