Direct access to bitarray from Cython

I can access bitarray bits using the indexing syntax:

from bitarray import bitarray

b = bitarray(10)
b[5]

How can I access the elements directly,

similar to how I access the elements of an array directly:

ary.data.as_ints[5]

instead of:

ary[5]

I'm asking because when I tried this with array, I got a 20-30x speedup in some cases.


I found what I need to access, but I don't know how to access it!

bitarray.h

See getbit() and setbit().

How can I access them from Cython?
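As a rough model of the arithmetic that single-bit access involves (not the actual C source of bitarray.h), here is a pure-Python sketch on a bytearray. It assumes the default big-endian bit order, in which bit 0 is the most significant bit of byte 0; the helpers `getbit`/`setbit` are hypothetical and only mirror the C names.

```python
def getbit(buf: bytearray, i: int) -> int:
    """Return bit i of buf, assuming big-endian bit order (bit 0 = MSB of byte 0)."""
    return (buf[i >> 3] >> (7 - (i & 7))) & 1


def setbit(buf: bytearray, i: int, value: int) -> None:
    """Set bit i of buf to value (0 or 1), with the same big-endian assumption."""
    mask = 1 << (7 - (i & 7))  # for little-endian order this would be 1 << (i & 7)
    if value:
        buf[i >> 3] |= mask
    else:
        buf[i >> 3] &= ~mask & 0xFF  # clear the bit, keep the result in 0..255


buf = bytearray(2)        # 16 bits, all zero
setbit(buf, 0, 1)         # MSB of byte 0
setbit(buf, 9, 1)         # second-highest bit of byte 1
print(buf.hex())          # 8040
print(getbit(buf, 9), getbit(buf, 8))  # 1 0
```

Each access is one byte index (`i >> 3`) plus one shift and mask, which is why going through the raw bytes avoids the cost of a Python-level `b[i]` call.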


Current speed:

Shape: (10000, 10000)
VSize: 100.00Mil
Mem: 12207.03kb, 11.92mb
                
----------------------
sa[5,5]=1
108 ns +- 0.451 ns per loop (mean +- std. dev. of 7 runs, 10000000 loops each)
sa[5,5]
146 ns +- 37.1 ns per loop (mean +- std. dev. of 7 runs, 10000000 loops each)
sa[100:120,100:120]
34.8 µs +- 7.39 µs per loop (mean +- std. dev. of 7 runs, 10000 loops each)
sa[:100,:100]
614 µs +- 135 µs per loop (mean +- std. dev. of 7 runs, 1000 loops each)
sa[[0,1,2],[0,1,2]]
1.11 µs +- 301 ns per loop (mean +- std. dev. of 7 runs, 1000000 loops each)
sa.sum()
6.74 ms +- 1.82 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa.sum(axis=0)
9.92 ms +- 2.49 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa.sum(axis=1)
646 ms +- 42.4 ms per loop (mean +- std. dev. of 7 runs, 1 loop each)
sa.mean()
5.17 ms +- 160 µs per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa.mean(axis=0)
12.8 ms +- 2.5 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa.mean(axis=1)
730 ms +- 25.1 ms per loop (mean +- std. dev. of 7 runs, 1 loop each)
sa[[9269, 5484, 2001, 8881, 30, 9567, 7654, 3034, 4901, 552],:],
6.87 ms +- 1.2 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa[:,[1417, 157, 9793, 1300, 2339, 2439, 2925, 3980, 4550, 5100]],
9.88 ms +- 1.56 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa[[9269, 5484, 2001, 8881, 30, 9567, 7654, 3034, 4901, 552],[1417, 157, 9793, 1300, 2339, 2439, 2925, 3980, 4550, 5100]],
6.59 µs +- 1.78 µs per loop (mean +- std. dev. of 7 runs, 100000 loops each)
sa[[9269, 5484, 2001, 8881, 30, 9567, 7654, 3034, 4901, 552],:].sum(axis=1),
466 ms +- 121 ms per loop (mean +- std. dev. of 7 runs, 1 loop each)

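I'd suggest using a typed memoryview (which gives you access to the data in chunks of 8 bits, i.e. bytes) and then using a bitwise AND to pick out the individual bits. This is definitely the simplest and most "native" way to do it in Cython.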

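cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def sum_bits1(ba):
    # bitarray supports the buffer protocol, so a typed memoryview of bytes can wrap it
    cdef unsigned char[::1] ba_view = ba
    cdef int count = 0
    cdef Py_ssize_t n, idx
    cdef unsigned char subidx, val
    for n in range(len(ba)):
        idx = n // 8                 # byte that holds bit n
        # bitarray's default bit order is big-endian: bit 0 is the MSB of byte 0
        # (for a little-endian bitarray use 1 << (n % 8) instead)
        subidx = 1 << (7 - n % 8)
        val = ba_view[idx] & subidx
        if val:
            count += 1
    return count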
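Because the memoryview hands out whole bytes, a count does not have to test bits one at a time. The pure-Python sketch below (a swapped-in byte-at-a-time popcount rather than the per-bit mask loop above) counts 8 bits per access and masks the final partial byte, assuming big-endian bit order. `count_bits` is a hypothetical helper; it masks the padding bits explicitly rather than relying on them being zero.

```python
def count_bits(buf, nbits: int) -> int:
    """Count set bits among the first nbits bits of buf (big-endian bit order assumed)."""
    view = memoryview(buf).cast("B")  # flat view of unsigned bytes
    full, rest = divmod(nbits, 8)     # complete bytes, bits in the last partial byte
    count = 0
    for k in range(full):
        count += bin(view[k]).count("1")  # one access covers 8 bits
    if rest:
        # big-endian: the valid bits of the partial byte are its top `rest` bits
        mask = (0xFF << (8 - rest)) & 0xFF
        count += bin(view[full] & mask).count("1")
    return count


print(count_bits(bytes([0b11001011, 0b11111111]), 12))  # 9
```

In Cython the same idea would replace `bin(...).count("1")` with a lookup table or a C popcount, but that part is left out of this sketch.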

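If you want to use the getbit and setbit functions defined in "bitarray.h", you just need to declare them as cdef extern functions. You need to find the path to "bitarray.h"; it is probably somewhere in your local pip installation directory. I've put the full path in the file, but a better solution would be to specify the include path in setup.py.
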
# adjust this path to wherever bitarray is installed on your system
cdef extern from "<path to home>/.local/lib/python3.8/site-packages/bitarray/bitarray.h":
    ctypedef struct bitarrayobject:
        pass # we don't need to know the details

    ctypedef class bitarray.bitarray [object bitarrayobject]:
        pass

    # the bit index is a Py_ssize_t in the C header
    int getbit(bitarray, Py_ssize_t)

def sum_bits2(bitarray ba):
    cdef int count = 0
    cdef Py_ssize_t n
    for n in range(len(ba)):
        if getbit(ba, n):
            count += 1
    return count
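To avoid hard-coding the header path in the extern block above, setup.py can look it up at build time. A minimal sketch, assuming (as the path above suggests) that bitarray.h sits in the installed bitarray package directory; `package_header_dir` and the module name `sum_bits.pyx` are hypothetical.

```python
import importlib.util
import os
from typing import Optional


def package_header_dir(package: str = "bitarray", header: str = "bitarray.h") -> Optional[str]:
    """Return the install directory of `package` if it contains `header`, else None."""
    spec = importlib.util.find_spec(package)
    if spec is None or spec.origin is None:
        return None  # not installed (or a namespace package without an origin)
    pkg_dir = os.path.dirname(spec.origin)
    return pkg_dir if os.path.isfile(os.path.join(pkg_dir, header)) else None


# In setup.py (sketch; assumes Cython is installed):
#   from setuptools import setup, Extension
#   from Cython.Build import cythonize
#   ext = Extension("sum_bits", ["sum_bits.pyx"],
#                   include_dirs=[package_header_dir()])
#   setup(ext_modules=cythonize([ext]))
```

With the directory on the include path, the .pyx file can then use `cdef extern from "bitarray.h":` instead of an absolute path.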

To test it (and compare it with a simple pure-Python version):

def sum_bits_naive(ba):
    count = 0
    for n in range(len(ba)):
        if ba[n]:
            count += 1
    return count

def test_funcs():
    from bitarray import bitarray
    
    ba = bitarray("110010"*10000)
    print(sum_bits1(ba), sum_bits2(ba), sum_bits_naive(ba))
    from timeit import timeit
    globs = dict(globals())
    globs.update(locals())
    print(timeit("sum_bits1(ba)", globals=globs, number=1000))
    print(timeit("sum_bits2(ba)", globals=globs, number=1000))
    print(timeit("sum_bits_naive(ba)", globals=globs, number=1000))

This gives:

(30000, 30000, 30000)
0.069798200041987
0.09307677199831232
1.3518586970167235

i.e. the memoryview version is the fastest of the three.