为什么类型列表的元素访问比使用 Numba 的数组慢得多?

Why is element access for typed lists so much slower than for arrays with Numba?

我想知道为什么在使用 Numba 时,类型化列表访问元素的速度比 NumPy 数组慢得多。我有下面显示的这个最小示例。我随机生成索引以防止在幕后进行任何编译器优化。似乎在校正生成所有这些随机数所需的时间后,在 NumPy 数组的情况下,每个元素访问几乎是瞬间发生的 (<1ns),而对于类型化列表,每次访问最多需要 100 ns。也许人们应该期望类型化列表会慢一点,但这对我来说似乎差异太大,如果需要大量访问列表元素,可能会显着降低代码速度。不幸的是,我不是计算机科学专家,所以我可能缺乏一些关于访问操作如何在这两种不同数据结构上工作的基本背景知识。那么,您知道为什么访问速度会有如此显着的差异吗?


import numpy as np
import numba as nb

@nb.njit
def only_rand(N):
    
    for _ in range(10000):
        
        i = np.random.randint(N)
        j = np.random.randint(N)

@nb.njit
def foo(pos, N):
    
    for _ in range(10000):
        
        i = np.random.randint(N)
        j = np.random.randint(N)
        
        dx = pos[i][0] - pos[j][0]
        dy = pos[i][1] - pos[j][1]
        dz = pos[i][2] - pos[j][2]
        

N = 100

Array = np.random.rand(N,3)
List = nb.typed.List(Array)

print('Random number generation:')
%timeit only_rand(N)
print('Numpy Array:')
%timeit foo(Array, N)
print('Typed List:')
%timeit foo(List, N)

输出:

Random number generation:
133 µs ± 4.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Numpy Array:
132 µs ± 881 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Typed List:
947 µs ± 74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

一个问题是基准测试存在缺陷。事实上,Numba JIT 编译器可以(部分地)看到你的计算大部分是无用的,因为它 大部分 没有 计算可见影响 :[= 13=、dydz 不会被读取,因此它们的计算(例如 pos[i][0] - pos[j][0])可以简单地忽略。乍一看,ij 似乎也是如此,但事实并非如此:np.random.randint 修改内部种子导致 副作用 .这种副作用迫使编译器仍然计算部分循环。

然而,除了以上几点,一旦基准修复,基于列表的实现确实更慢。它来自临时列表的引用计数。并且汇编代码JIT优化得不太好(列表往往会生成更复杂的代码,更难优化)。


深入分析:

要看JIT优化代码,可以把N的值调大很多。这是我机器上的时间:

With N=100:
  Random number generation:
  109 µs ± 13.1 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
  Numpy Array:
  113 µs ± 22.7 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
  Typed List:
  806 µs ± 197 µs per loop (mean ± std. dev. of 7 runs, 250 loops each)

With N=1_000_000:
  Random number generation:
  64.7 µs ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
  Numpy Array:
  68.6 µs ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
  Typed List:
  804 µs ± 215 µs per loop (mean ± std. dev. of 7 runs, 250 loops each)

With N=10_000_000:
  Random number generation:
  185 µs ± 13.7 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
  Numpy Array:
  190 µs ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
  Typed List:
  839 µs ± 200 µs per loop (mean ± std. dev. of 7 runs, 250 loops each)

请注意,时间不是很依赖 N

这两种实现的汇编代码相当庞大,但可以看出主循环是相似的,并且在两种情况下都包含对 numba_rnd_shuffle 的调用,这些调用未经过优化(由于 [= 的副作用) 19=]).这是一个例子:

.LBB0_20: <----------\
        cmpl    4, %eax
        jae     .LBB0_21
.LBB0_22: <----------\
        movl    %eax, %ecx
        movl    4(%rsi,%rcx,4), %ebp
        leal    1(%rax), %ecx
        movl    %ecx, (%rsi)
        movl    %ebp, %edx
        shrl    , %edx
        xorl    %ebp, %edx
        movl    %edx, %ebp
        shll    , %ebp
        andl    $-1658038656, %ebp
        xorl    %edx, %ebp
        movl    %ebp, %edx
        shll    , %edx
        andl    $-272236544, %edx
        xorl    %ebp, %edx
        movl    %edx, %ebp
        shrl    , %ebp
        xorl    %edx, %ebp
        andl    %edi, %ebp
        cmpl    3, %eax
        jae     .LBB0_23
.LBB0_24: <----------\
        movl    %ecx, %eax
        movl    4(%rsi,%rax,4), %eax
        incl    %ecx
        movl    %ecx, (%rsi)
        movl    %eax, %edx
        shrl    , %edx
        xorl    %eax, %edx
        movl    %edx, %eax
        shll    , %eax
        andl    $-1658038656, %eax
        xorl    %edx, %eax
        movl    %eax, %edx
        shll    , %edx
        andl    $-272236544, %edx
        xorl    %eax, %edx
        movl    %edx, %eax
        shrl    , %eax
        xorl    %edx, %eax
        shlq    , %rbp
        orq     %rax, %rbp
        movl    %ecx, %eax
        cmpq    %r14, %rbp
        jge     .LBB0_20 ---------->
        jmp     .LBB0_12
.LBB0_21:
        movq    %rsi, %rcx
        movabsq $numba_rnd_shuffle, %rax
        callq   *%rax
        movl    [=11=], (%rsi)
        xorl    %eax, %eax
        jmp     .LBB0_22 ---------->
.LBB0_23:
        movq    %rsi, %rcx
        movabsq $numba_rnd_shuffle, %rax
        callq   *%rax
        movl    [=11=], (%rsi)
        xorl    %ecx, %ecx
        jmp     .LBB0_24 ---------->
        .p2align        4, 0x90

问题是在每次迭代结束时,以下汇编代码重复 6 次:

        movabsq $numba_list_size_address, %rdi

        movq    %r13, %rcx
        movabsq $NRT_incref, %rax
        callq   *%rax                  # NRT_incref(ptrVar);

        movq    %r15, %rcx
        callq   *%rdi                  # tmp1 = numba_list_size_address(listVar);

        movq    %rbp, %r12
        sarq    , %r12
        movq    (%rax), %rbx           # tmp2 = fancy_operation(*tmp1)
        andq    %r12, %rbx
        addq    %rbp, %rbx
        js      .LBB0_34               # Conditional goto to the end (overflow check?)

        movq    %r15, %rcx
        callq   *%rdi                  # tmp3 = numba_list_size_address(listVar);

        movq    %r15, %rdi
        movq    (%rax), %r15
        movq    %r13, %rcx
        movabsq $NRT_decref, %rax
        callq   *%rax                  # NRT_decref(ptrVar);

        cmpq    %r15, %rbx             # if(*tmp3 >= tmp2)
        jge     .LBB0_33               #     goto end;

        movq    %rdi, %rcx
        movabsq $numba_list_base_ptr, %rax
        callq   *%rax                  # numba_list_base_ptr(listVar);

可以看到引用计数调用以及与列表相关的函数。这部分汇编代码来自表达式 pos[i]pos[j]。 JIT 未优化列表对象的引用计数。相关检查似乎也是如此。

我猜这是因为在这种情况下无法优化 Numba 函数调用,或者 JIT 认为它不够昂贵。列表相关函数的代码可以找到here。我发现 JIT 没有优化与列表相关的函数调用很奇怪,因为它们被 Numba 标记为 alwaysinlinereadonly...无论如何,我认为这是一个 missed优化,可以改进。

我向 Numba 开发人员提交了一个问题 here