如何使用 numba 优化 numpy.packbits?

How do I optimise numpy.packbits with numba?

我正在尝试优化 numpy.packbits:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def _numba_pack(arr, div, su):
    for i in prange(div):
        s = 0
        for j in range(i*8, i*8+8):
            s = 2*s + arr[j]
        su[i] = s
        
def numba_packbits(arr):
    div, mod = np.divmod(arr.size, 8)
    su = np.zeros(div + (mod>0), dtype=np.uint8)
    _numba_pack(arr[:div*8], div, su)
    if mod > 0:
        su[-1] = sum(x*y for x,y in zip(arr[div*8:], (128, 64, 32, 16, 8, 4, 2, 1)))
    return su

>>> X = np.random.randint(2, size=99, dtype=bool)
>>> print(numba_packbits(X))
[ 75  24  79  61 209 189 203 187  47 226 170  61   0]

它看起来比 np.packbits(X) 慢 2 - 2.5 倍。 numpy 内部是如何实现的?这可以在 numba 中得到改进吗?

我致力于通过 conda install 安装的 numpy == 1.21.2numba == 0.53.1。我的平台是:

结果:

import benchit
from numpy import packbits
%matplotlib inline
benchit.setparams(rep=5)

sizes = [100000, 300000, 1000000, 3000000, 10000000, 30000000]
N = sizes[-1]
arr = np.random.randint(2, size=N, dtype=bool)
fns = [numba_packbits, packbits]

in_ = {s/1000000: (arr[:s], ) for s in sizes}
t = benchit.timings(fns, in_, multivar=True, input_name='Millions of bits')
t.plot(logx=True, figsize=(12, 6), fontsize=14)

更新

Jérôme 的回复:

@njit('void(bool_[::1], uint8[::1], int_)', inline='never')
def _numba_pack_x64_byJérôme(arr, su, pos):
    for i in range(64):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
       
@njit(parallel=True)
def _numba_pack_byJérôme(arr, div, su):
    for i in prange(div//64):
        _numba_pack_x64_byJérôme(arr[i*8:(i+64)*8], su[i:i+64], i)
    for i in range(div//64*64, div):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
        
def numba_packbits_byJérôme(arr):
    div, mod = np.divmod(arr.size, 8)
    su = np.zeros(div + (mod>0), dtype=np.uint8)
    _numba_pack_byJérôme(arr[:div*8], div, su)
    if mod > 0:
        su[-1] = sum(x*y for x,y in zip(arr[div*8:], (128, 64, 32, 16, 8, 4, 2, 1)))
    return su

用法:

>>> print(numba_packbits_byJérôme(X))
[ 75  24  79  61 209 189 203 187  47 226 170  61   0]

结果:

Numba 实施存在几个问题。其中之一是并行循环 破坏了 LLVM-Lite(Numba 使用的 JIT 编译器)中的恒定传播优化。这会导致关键信息(如数组跨度)无法传播,从而导致标量实现速度较慢,而不是 SIMD 实现,以及额外的不需要指令来计算偏移量。在 C 代码中也可以看到此类问题。 Numpy 添加了特定的宏,因此可以帮助编译器在工作维度的步幅实际为 1 时自动对代码进行矢量化(即使用 SIMD 指令)。

克服持续传播问题的一种解决方案是调用另一个 Numba 函数。此函数必须 而不是 内联。应该手动提供签名,以便编译器可以在编译时知道数组的步幅为 1 并生成更快的代码。最后,函数应该在固定大小的块上工作,因为函数调用很昂贵并且编译器可以对代码进行矢量化。 使用 shifts 展开循环也会产生更快的代码(尽管它更难看)。这是一个例子:

@njit('void(bool_[::1], uint8[::1], int_)', inline='never')
def _numba_pack_x64(arr, su, pos):
    for i in range(64):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]

@njit('void(bool_[::1], int_, uint8[::1])', parallel=True)
def _numba_pack(arr, div, su):
    for i in prange(div//64):
        _numba_pack_x64(arr[i*8:(i+64)*8], su[i:i+64], i)
    for i in range(div//64*64, div):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]

基准

以下是我的 6 核机器 (i5-9600KF) 的性能结果,输入十亿个随机项:

Initial Numba (seq):    189 ms  (x0.7)
Initial Numba (par):    141 ms  (x1.0)
Numpy (seq):             98 ms  (x1.4)
Optimized Numba (par):   35 ms  (x4.0)
Theoretical optimal:     27 ms  (x5.2)  [fully memory-bound case]

这个新实现比最初的并行实现快 4 倍,比 Numpy 快 3 倍


深入研究生成的汇编代码

设置parallel=False并将prange替换为range时,在我支持AVX-2的英特尔处理器上生成以下汇编代码:

.LBB0_7:
    vmovdqu 112(%rdx,%rax,8), %xmm1
    vmovdqa 384(%rsp), %xmm3
    vpshufb %xmm3, %xmm1, %xmm0
    vmovdqu 96(%rdx,%rax,8), %xmm2
    vpshufb %xmm3, %xmm2, %xmm3
    vpunpcklwd  %xmm0, %xmm3, %xmm3
    vmovdqu 80(%rdx,%rax,8), %xmm15
    vmovdqa 368(%rsp), %xmm5
    vpshufb %xmm5, %xmm15, %xmm4
    vmovdqu 64(%rdx,%rax,8), %xmm0
    [...] <------------------------------  ~180 other instructions discarded
    vpcmpeqb    %xmm3, %xmm11, %xmm2
    vpandn  %xmm8, %xmm2, %xmm2
    vpor    %xmm2, %xmm1, %xmm1
    vpcmpeqb    %xmm3, %xmm0, %xmm0
    vpaddb  %xmm0, %xmm1, %xmm0
    vpsubb  %xmm4, %xmm0, %xmm0
    vmovdqu %xmm0, (%r11,%rax)
    addq    , %rax
    cmpq    %rax, %rsi
    jne .LBB0_7

代码不是很好因为它使用了很多不需要的指令(比如 SIMD 比较指令可能是由于布尔类型的隐式转换),很多寄存器是临时存储的(寄存器溢出)并且它使用 128 位我的机器支持 AVX 向量而不是 256 位 AVX 向量。也就是说,代码是矢量化的,每个循环迭代一次写入 16 个字节,没有任何条件分支(循环之一除外),因此结果性能还不错。

事实上,Numpy 代码更小且更高效。这就是为什么它比我机器上具有大输入的顺序 Numba 代码快约 2 倍的原因。这是热汇编循环:

4e8:
    mov      (%rdx,%rax,8),%rcx
    bswap    %rcx
    mov      %rcx,0x20(%rsp)
    mov      0x8(%rdx,%rax,8),%rcx
    add      [=13=]x2,%rax
    movq     0x20(%rsp),%xmm0
    bswap    %rcx
    mov      %rcx,0x20(%rsp)
    movhps   0x20(%rsp),%xmm0
    pcmpeqb  %xmm1,%xmm0
    pcmpeqb  %xmm1,%xmm0
    pmovmskb %xmm0,%ecx
    mov      %cl,(%rsi)
    movzbl   %ch,%ecx
    mov      %cl,(%rsi,%r13,1)
    add      %r9,%rsi
    cmp      %rax,%r8
    jg       4e8

它以 8 字节为单位读取值,并使用 128 位 SSE 指令部分计算它们。每次迭代写入 2 个字节。话虽如此,它也不是最优的,因为没有使用 256 位 SIMD 指令,我认为代码可以进一步优化。

使用初始并行代码时,这里是热循环的汇编代码:

.LBB3_4:
     movq %r9, %rax
     leaq (%r10,%r14), %r9
     movq %r15, %rsi
     sarq , %rsi
     andq %rdx, %rsi
     addq %r11, %rsi
     cmpb [=14=], (%r14,%rsi)
     setne     %cl
     addb %cl, %cl
     [...] <---------------  56 instructions (with few 256-bit AVX ones)
     orb  %bl, %cl
     orb  %al, %cl
     orb  %dl, %cl
     movq %rbp, %rdx
     movb %cl, (%r8,%r15)
     incq %r15
     decq %rdi
     addq , %r14
     cmpq , %rdi
     jg   .LBB3_4

上面的代码主要是没有向量化,效率很低。它使用大量指令(包括非常慢的指令,如 setne/cmovlq/cmpb 来执行许多条件存储)​​每次迭代仅一次写入 1 个字节。对于相同数量的写入字节,Numpy 执行的指令少了大约 8 倍。此代码的低效率通过使用多线程得到缓解。最后,并行版本在具有许多内核(例如 >= 6)的机器上可以更快一些。

本答案开头提供的改进实现生成了类似于上述顺序实现的代码,但使用了多线程(目前还远未达到最佳,但更胜一筹)。