numba 中的整数数组
Array of ints in numba
我正在计算 int8
向量中出现频率最高的数字。当我设置 int
s:
的计数器数组时,Numba 抱怨
@jit(nopython=True)
def freq_int8(y):
"""Find most frequent number in array"""
count = np.zeros(256, dtype=int)
for val in y:
count[val] += 1
return ((np.argmax(count)+128) % 256) - 128
调用它时出现以下错误:
TypingError: Invalid usage of Function(<built-in function zeros>) with parameters (int64, Function(<class 'int'>))
如果我删除 dtype=int
它会起作用并且我得到了不错的加速。然而,我对为什么声明 int
s 的数组不起作用感到困惑。是否有已知的解决方法,是否有任何值得在这里获得的效率提升?
背景:我正在尝试从一些 numpy-heavy 代码中减少微秒。我尤其受到 numpy.median
的伤害,并且一直在研究 Numba,但正在努力改进 median
。找到最频繁的数字是 median
的一个可接受的替代方法,在这里我已经能够获得一些性能。上面的numba代码也比numpy.bincount
.
快
更新: 在输入已接受的答案后,这里是 int8
向量的 median
的实现。它大约比 numpy.median
:
快一个数量级
@jit(nopython=True)
def median_int8(y):
N2 = len(y)//2
count = np.zeros(256, dtype=np.int32)
for val in y:
count[val] += 1
cs = 0
for i in range(-128, 128):
cs += count[i]
if cs > N2:
return float(i)
elif cs == N2:
j = i+1
while count[j] == 0:
j += 1
return (i + j)/2
令人惊讶的是,短向量的性能差异甚至更大,显然是由于 numpy
向量的开销:
>>> a = np.random.randint(-128, 128, 10)
>>> %timeit np.median(a)
The slowest run took 7.03 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 20.8 µs per loop
>>> %timeit median_int8(a)
The slowest run took 11.67 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 593 ns per loop
这个开销太大了,不知道是不是哪里出了问题
快速说明一下,找到出现次数最多的数字通常称为 mode, and it is as similar to the median as it is the mean... in which case np.mean
will be considerably faster. Unless you have some constrains or particularities in your data, there is no guarantee that the mode approximates the median。
如果你仍然想计算整数列表的 mode,正如你提到的,np.bincount
应该足够了(如果 numba 更快,它应该不会太多):
count = np.bincount(y, minlength=256)
result = ((np.argmax(count)+128) % 256) - 128
请注意,我已将 minlength
参数添加到 np.bincount
,因此它 return 与代码中的 256 长度列表相同。但在实践中完全没有必要,因为你只想要 argmax
、np.bincount
(没有 minlength
)将 return 一个列表,其长度是 [=18] 中的最大数字=].
至于 numba 错误,将 dtype=int
替换为 dtype=np.int32
应该可以解决问题。 int
是一个 python 函数,您在 numba header 中指定了 nopython
。如果您删除 nopython
,那么 dtype=int
或 dtype='i'
也将起作用(具有相同的效果)。
我正在计算 int8
向量中出现频率最高的数字。当我设置 int
s:
@jit(nopython=True)
def freq_int8(y):
"""Find most frequent number in array"""
count = np.zeros(256, dtype=int)
for val in y:
count[val] += 1
return ((np.argmax(count)+128) % 256) - 128
调用它时出现以下错误:
TypingError: Invalid usage of Function(<built-in function zeros>) with parameters (int64, Function(<class 'int'>))
如果我删除 dtype=int
它会起作用并且我得到了不错的加速。然而,我对为什么声明 int
s 的数组不起作用感到困惑。是否有已知的解决方法,是否有任何值得在这里获得的效率提升?
背景:我正在尝试从一些 numpy-heavy 代码中减少微秒。我尤其受到 numpy.median
的伤害,并且一直在研究 Numba,但正在努力改进 median
。找到最频繁的数字是 median
的一个可接受的替代方法,在这里我已经能够获得一些性能。上面的numba代码也比numpy.bincount
.
更新: 在输入已接受的答案后,这里是 int8
向量的 median
的实现。它大约比 numpy.median
:
@jit(nopython=True)
def median_int8(y):
N2 = len(y)//2
count = np.zeros(256, dtype=np.int32)
for val in y:
count[val] += 1
cs = 0
for i in range(-128, 128):
cs += count[i]
if cs > N2:
return float(i)
elif cs == N2:
j = i+1
while count[j] == 0:
j += 1
return (i + j)/2
令人惊讶的是,短向量的性能差异甚至更大,显然是由于 numpy
向量的开销:
>>> a = np.random.randint(-128, 128, 10)
>>> %timeit np.median(a)
The slowest run took 7.03 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 20.8 µs per loop
>>> %timeit median_int8(a)
The slowest run took 11.67 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 593 ns per loop
这个开销太大了,不知道是不是哪里出了问题
快速说明一下,找到出现次数最多的数字通常称为 mode, and it is as similar to the median as it is the mean... in which case np.mean
will be considerably faster. Unless you have some constrains or particularities in your data, there is no guarantee that the mode approximates the median。
如果你仍然想计算整数列表的 mode,正如你提到的,np.bincount
应该足够了(如果 numba 更快,它应该不会太多):
count = np.bincount(y, minlength=256)
result = ((np.argmax(count)+128) % 256) - 128
请注意,我已将 minlength
参数添加到 np.bincount
,因此它 return 与代码中的 256 长度列表相同。但在实践中完全没有必要,因为你只想要 argmax
、np.bincount
(没有 minlength
)将 return 一个列表,其长度是 [=18] 中的最大数字=].
至于 numba 错误,将 dtype=int
替换为 dtype=np.int32
应该可以解决问题。 int
是一个 python 函数,您在 numba header 中指定了 nopython
。如果您删除 nopython
,那么 dtype=int
或 dtype='i'
也将起作用(具有相同的效果)。