Python 中的快速按位求和

Fast Bitwise Sum in Python

是否有一种有效的方法来计算 Python 中数组中每一列的总和?

示例(Python 3.7 和 Numpy 1.20.1):

  1. 创建值为 0 或 1 的 numpy 数组
import numpy as np

array = np.array(
    [
     [1, 0, 1],   
     [1, 1, 1], 
     [0, 0, 1],    
    ]
)
  1. 将大小压缩 np.packbits
pack_array = np.packbits(array, axis=1)
  1. 预期结果:没有np.unpackbits的每个位置(列)的位总和与array.sum(axis=0)相同:
array([2, 1, 3])

我发现解决方案非常慢:

dim = array.shape[1]
candidates = np.zeros((dim, dim)).astype(int)
np.fill_diagonal(candidates, 1)

pack_candidates = np.packbits(candidates, axis=1)

np.apply_along_axis(lambda c:np.sum((np.bitwise_and(pack_array, c) == c).all(axis=1)), 1, pack_candidates)

在numpy中似乎没有比numpy.unpackbits更好的选择了。

为了更清楚,我们再举个例子:

array = np.array([[1, 0, 1, 0, 1, 1, 1, 0, 1], 
                  [1, 1, 1, 1, 1, 1, 1, 1, 1], 
                  [0, 0, 1, 0, 0, 0, 0, 0, 0]])
pack_array = np.packbits(array, axis=1)
dim = array.shape[1]

现在,pack_array是这样计算的:

[[1,0,1,0,1,1,1,0], [1,0,0,0,0,0,0,0]] -> [174, 128]
[[1,1,1,1,1,1,1,1], [1,0,0,0,0,0,0,0]] -> [255, 128]
[[0,0,1,0,0,0,0,0], [0,0,0,0,0,0,0,0]] -> [32, 0]

我测试了各种算法,解包位似乎是最快的:

def numpy_sumbits(pack_array, dim):
    out = np.unpackbits(pack_array, axis=1, count=dim)
    arr = np.sum(out, axis=0)
    return arr

def manual_sumbits(pack_array, dim):
    arr = pack_array.copy()
    out = np.empty((dim//8+1) * 8, dtype=int)
    for i in range(8):
        out[7 - i%8::8] = np.sum(arr % 2, axis=0)
        arr = arr // 2
    return out[:dim]

def numpy_sumshifts(pack_array, dim):
    res = (pack_array.reshape(pack_array.size, -1) >> np.arange(8)) % 2
    res = res.reshape(*pack_array.shape, 8)
    return np.sum(res, axis=0)[:,::-1].ravel()[:dim]

print(numpy_unpackbits(pack_array, dim))
print(manual_unpackbits(pack_array, dim))
print(numpy_sumshifts(pack_array, dim))
>>>
[2 1 3 1 2 2 2 1 2]
[2 1 3 1 2 2 2 1 2]
[2 1 3 1 2 2 2 1 2]

%%timeit
numpy_sumbits(pack_array, dim)
>>> 3.49 ms ± 57.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
manual_sumbits(pack_array, dim)
>>> 10 ms ± 22.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
numpy_sumshifts(pack_array, dim)
>>> 20.1 ms ± 97.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

如果输入数组很大,使用 np.unpackbits 可能会出现问题,因为结果数组可能 太大而无法放入 RAM,即使它确实适合RAM,这远非高效,因为必须从(慢速)主内存写入和读取巨大的数组。同样的事情适用于 CPU 缓存: 较小的数组通常可以更快地计算 。此外,np.unpackbits对于小数组有相当大的开销。

AFAIK,在使用少量 RAM(即使用 np.unpackbits,如 @mathfux 所指出的)时,在 Numpy 中不可能非常有效地执行此操作。但是,Numba 可用于加速此计算,尤其是对于小型数组。这是代码:

@nb.njit('int32[::1](uint8[:,::1], int_)')
def bitSum(packed, m):
    n = packed.shape[0]
    assert packed.shape[1]*8-7 <= m <= packed.shape[1]*8
    res = np.zeros(m, dtype=np.int32)
    for i in range(n):
        for j in range(m):
            res[j] += bool(packed[i, j//8] & (128>>(j%8)))
    return res

如果您想要更快的实施,您可以通过处理固定大小的图块来优化代码。但是,这也使代码更加复杂。这是结果代码:

@nb.njit('int32[::1](uint8[:,::1], int_)')
def bitSumOpt(packed, m):
    n = packed.shape[0]
    assert packed.shape[1]*8-7 <= m <= packed.shape[1]*8
    res = np.zeros(m, dtype=np.int32)
    for i in range(0, n, 4):
        for j in range(0, m, 8):
            if i+3 < n and j+7 < m:
                # Highly-optimized 4x8 tile computation
                k = j//8
                b0, b1, b2, b3 = packed[i,k], packed[i+1,k], packed[i+2,k], packed[i+3,k]
                for j2 in range(8):
                    shift = 7 - j2
                    mask = 1 << shift
                    res[j+j2] += ((b0 & mask) + (b1 & mask) + (b2 & mask) + (b3 & mask)) >> shift
            else:
                # Slow fallback computation
                for i2 in range(i, min(i+4, n)):
                    for j2 in range(j, min(j+8, m)):
                        res[j2] += bool(packed[i2, j2//8] & (128>>(j2%8)))
    return res

以下是我机器上的性能结果:

On the example array:
Initial code:    62.90 us   (x1)
numpy_sumbits:    4.37 us   (x14)
bitSumOpt:        0.84 us   (x75)
bitSum:           0.77 us   (x82)

On a random 2000x2000 array:
Initial code:  1203.8  ms   (x1)
numpy_sumbits:    3.9  ms   (x308)
bitSum:           2.7  ms   (x446)
bitSumOpt:        1.5  ms   (x802)

Numba 实现的内存占用也好得多(至少小 8 倍)。