为什么这个 Rust 代码中没有分支预测失败惩罚?

Why isn't there a branch prediction failure penalty in this Rust code?

我写了这个非常简单的 Rust 函数:

fn iterate(nums: &Box<[i32]>) -> i32 {
    let mut total = 0;
    let len = nums.len();
    for i in 0..len {
        if nums[i] > 0 {
            total += nums[i];
        } else {
            total -= nums[i];
        }
    }

    total
}

我已经编写了一个基本基准测试,它使用有序数组和随机数组调用方法:

fn criterion_benchmark(c: &mut Criterion) {
    const SIZE: i32 = 1024 * 1024;

    let mut group = c.benchmark_group("Branch Prediction");

    // setup benchmarking for an ordered array
    let mut ordered_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        ordered_nums.push(i - SIZE/2);
    }
    let ordered_nums = ordered_nums.into_boxed_slice();
    group.bench_function("ordered", |b| b.iter(|| iterate(&ordered_nums)));

    // setup benchmarking for a shuffled array
    let mut shuffled_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        shuffled_nums.push(i - SIZE/2);
    }
    let mut rng = thread_rng();
    let mut shuffled_nums = shuffled_nums.into_boxed_slice();
    shuffled_nums.shuffle(&mut rng);
    group.bench_function("shuffled", |b| b.iter(|| iterate(&shuffled_nums)));

    group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

令我惊讶的是,这两个基准测试具有几乎完全相同的运行时间,而 Java 中的一个类似基准测试显示了两者之间的明显差异,这可能是由于随机情况下的分支预测失败。

我看到有条件移动指令的提及,但是如果我 otool -tv 可执行文件(我在 Mac 上 运行),我没有看到任何iterate 方法输出。

谁能阐明为什么在 Rust 中有序和无序情况之间没有明显的性能差异?

总结:LLVM 能够通过使用 cmov 指令或 SIMD 指令的真正巧妙组合来 remove/hide 分支。


我用 Godbolt view the full assembly(用 -C opt-level=3)。我将在下面解释组装的重要部分。

它是这样开始的:

        mov     r9, qword ptr [rdi + 8]         ; r9 = nums.len()
        test    r9, r9                          ; if len == 0
        je      .LBB0_1                         ;     goto LBB0_1
        mov     rdx, qword ptr [rdi]            ; rdx = base pointer (first element)
        cmp     r9, 7                           ; if len > 7
        ja      .LBB0_5                         ;     goto LBB0_5
        xor     eax, eax                        ; eax = 0
        xor     esi, esi                        ; esi = 0
        jmp     .LBB0_4                         ; goto LBB0_4

.LBB0_1:
        xor     eax, eax                        ; return 0
        ret

这里,函数区分了3种不同的"states":

  • 切片为空 → return 0 立即
  • 切片长度 ≤ 7 → 使用标准顺序算法 (LBB0_4)
  • 切片长度 > 7 → 使用 SIMD 算法 (LBB0_5)

那么让我们来看看这两种不同的算法吧!


标准顺序算法

记住 rsi (esi) 和 rax (eax) 被设置为 0 并且 rdx 是指向数据的基指针.

.LBB0_4:
        mov     ecx, dword ptr [rdx + 4*rsi]    ; ecx = nums[rsi]
        add     rsi, 1                          ; rsi += 1
        mov     edi, ecx                        ; edi = ecx
        neg     edi                             ; edi = -edi
        cmovl   edi, ecx                        ; if ecx >= 0 { edi = ecx }
        add     eax, edi                        ; eax += edi
        cmp     r9, rsi                         ; if rsi != len
        jne     .LBB0_4                         ;     goto LBB0_4
        ret                                     ; return eax

这是一个遍历 num 的所有元素的简单循环。不过在循环体中有一个小技巧:从原始元素 ecx 开始,一个取反值存储在 edi 中。通过使用cmovledi被原始值覆盖如果原始值为正。这意味着 edi 将始终为正(即包含原始元素的绝对值)。然后它被添加到 eax (最后是 returned)。

所以你的 if 分支隐藏在 cmov 指令中。正如您在 this benchmark 中所见,执行 cmov 指令所需的时间与条件的概率无关。这是一个非常惊人的指令!


SIMD算法

SIMD 版本包含很多指令,我不会在此处完整粘贴。主循环一次处理 16 个整数!

        movdqu  xmm5, xmmword ptr [rdx + 4*rdi]
        movdqu  xmm3, xmmword ptr [rdx + 4*rdi + 16]
        movdqu  xmm0, xmmword ptr [rdx + 4*rdi + 32]
        movdqu  xmm1, xmmword ptr [rdx + 4*rdi + 48]

它们从内存加载到寄存器 xmm0xmm1xmm3xmm5。这些寄存器中的每一个都包含四个 32 位值,但为了更容易理解,只需想象每个寄存器只包含一个值。以下所有指令分别对这些 SIMD 寄存器的每个值进行操作,因此心智模型很好!我下面的解释也听起来好像 xmm 寄存器只包含一个值。

主要技巧现在在以下说明中(处理 xmm5):

        movdqa  xmm6, xmm5      ; xmm6 = xmm5 (make a copy)
        psrad   xmm6, 31        ; logical right shift 31 bits (see below)
        paddd   xmm5, xmm6      ; xmm5 += xmm6
        pxor    xmm5, xmm6      ; xmm5 ^= xmm6

逻辑右移用符号位的值填充"empty high-order bits"(左边的"shifted in")。通过移动 31,我们最终 每个位置只有符号位 !所以任何正数都会变成 32 个零,任何负数都会变成 32 个一。所以 xmm6 现在是 000...000(如果 xmm5 为正数)或 111...111(如果 xmm5 为负数)。

接下来将这个人工 xmm6 添加到 xmm5。如果 xmm5 为正数,则 xmm6 为 0,因此添加它不会改变 xmm5。但是,如果 xmm5 为负,我们添加 111...111,这相当于减去 1。最后,我们将 xmm5xmm6 异或。同样,如果 xmm5 一开始是正数,我们与 000...000 异或,这没有效果。如果 xmm5 一开始是负数,我们与 111...111 异或,这意味着我们翻转所有位。所以对于这两种情况:

  • 如果元素为正,我们什么都不改变(addxor 没有任何影响)
  • 如果元素为负,我们减去 1 并翻转所有位。 这是一个二进制补码取反!

所以用这4条指令我们计算出了xmm5的绝对值!同样,由于这个小技巧,这里没有分支。请记住 xmm5 实际上包含 4 个整数,所以速度非常快!

这个绝对值现在被添加到一个累加器,并且对其他三个包含切片值的 xmm 寄存器也进行了同样的操作。 (我们不会详细讨论剩余的代码。)


带有 AVX2 的 SIMD

如果我们允许 LLVM 发出 AVX2 指令(通过 -C target-feature=+avx2),它甚至可以使用 pabsd 指令而不是四个 "hacky" 指令:

vpabsd  ymm2, ymmword ptr [rdx + 4*rdi]

它直接从内存中加载值,计算绝对值并在一条指令中将其存储在 ymm2 中!请记住 ymm 寄存器是 xmm 寄存器的两倍(适合八个 32 位值)!