numba 中的整数数组

Array of ints in numba

我正在计算 int8 向量中出现频率最高的数字。当我设置 ints:

的计数器数组时,Numba 抱怨
@jit(nopython=True)
def freq_int8(y):
    """Find most frequent number in array"""
    count = np.zeros(256, dtype=int)
    for val in y:
        count[val] += 1
    return ((np.argmax(count)+128) % 256) - 128

调用它时出现以下错误:

TypingError: Invalid usage of Function(<built-in function zeros>) with parameters (int64, Function(<class 'int'>))

如果我删除 dtype=int 它会起作用并且我得到了不错的加速。然而,我对为什么声明 ints 的数组不起作用感到困惑。是否有已知的解决方法,是否有任何值得在这里获得的效率提升?

背景:我正在尝试从一些 numpy-heavy 代码中减少微秒。我尤其受到 numpy.median 的伤害,并且一直在研究 Numba,但正在努力改进 median。找到最频繁的数字是 median 的一个可接受的替代方法,在这里我已经能够获得一些性能。上面的numba代码也比numpy.bincount.

更新: 在输入已接受的答案后,这里是 int8 向量的 median 的实现。它大约比 numpy.median:

快一个数量级
@jit(nopython=True)
def median_int8(y):
    N2 = len(y)//2
    count = np.zeros(256, dtype=np.int32)
    for val in y:
        count[val] += 1
    cs = 0
    for i in range(-128, 128):
        cs += count[i]
        if cs > N2:
            return float(i)
        elif cs == N2:
            j = i+1
            while count[j] == 0:
                j += 1
            return (i + j)/2

令人惊讶的是,短向量的性能差异甚至更大,显然是由于 numpy 向量的开销:

>>> a = np.random.randint(-128, 128, 10)

>>> %timeit np.median(a)
    The slowest run took 7.03 times longer than the fastest. This could mean that an intermediate result is being cached.
    10000 loops, best of 3: 20.8 µs per loop

>>> %timeit median_int8(a)
    The slowest run took 11.67 times longer than the fastest. This could mean that an intermediate result is being cached.
    1000000 loops, best of 3: 593 ns per loop

这个开销太大了,不知道是不是哪里出了问题

快速说明一下,找到出现次数最多的数字通常称为 mode, and it is as similar to the median as it is the mean... in which case np.mean will be considerably faster. Unless you have some constrains or particularities in your data, there is no guarantee that the mode approximates the median

如果你仍然想计算整数列表的 mode,正如你提到的,np.bincount 应该足够了(如果 numba 更快,它应该不会太多):

count = np.bincount(y, minlength=256)
result = ((np.argmax(count)+128) % 256) - 128

请注意,我已将 minlength 参数添加到 np.bincount,因此它 return 与代码中的 256 长度列表相同。但在实践中完全没有必要,因为你只想要 argmaxnp.bincount(没有 minlength)将 return 一个列表,其长度是 [=18] 中的最大数字=].

至于 numba 错误,将 dtype=int 替换为 dtype=np.int32 应该可以解决问题。 int 是一个 python 函数,您在 numba header 中指定了 nopython。如果您删除 nopython,那么 dtype=intdtype='i' 也将起作用(具有相同的效果)。