Count the number of non zero values in a numpy array in Numba
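Simple enough. I'm trying to count the number of non-zero values in a NumPy array inside a function JIT-compiled with Numba (njit()). Numba doesn't allow any of the following, which I've tried: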
a[a != 0].size
np.count_nonzero(a)
len(a[a != 0])
len(a) - len(a[a == 0])
I'd rather not use a for loop if there is a faster, more Pythonic and more elegant way.
For the commenters who wanted to see a complete code example...
import numpy as np
from numba import njit
@njit()
def n_nonzero(a):
    return a[a != 0].size
You can use np.nonzero and take the length of its result:
@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])
count_non_zero(np.array([0,1,0,1]))
# 2
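Why taking the length of the first element works: np.nonzero returns a tuple with one index array per dimension, and every array in that tuple has exactly one entry per non-zero element. A small pure-NumPy illustration (run outside Numba, on a made-up 2-D array):

```python
import numpy as np

a = np.array([[0, 1, 2],
              [3, 0, 0]])

# One index array per dimension: (row indices, column indices)
rows, cols = np.nonzero(a)

# Each index array holds one entry per non-zero element,
# so the length of the first one is the non-zero count
print(rows, cols)             # [0 0 1] [1 2 0]
print(len(np.nonzero(a)[0]))  # 3
```

Note that this means np.nonzero builds full index arrays only for their length to be taken, which may help explain why it benchmarks slower than the alternatives below.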
Not sure whether I've made a mistake here, but this seems to be about 6 times faster:
# Make something worth checking
a=np.random.randint(0,3,1000000000,dtype=np.uint8)
In [41]: @njit()
    ...: def methodA(a):
    ...:     return len(np.nonzero(a)[0])
# Call and check result
In [42]: methodA(a)
Out[42]: 666644445
In [43]: %timeit methodA(a)
4.65 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [44]: @njit()
    ...: def methodB(a):
    ...:     return (a!=0).sum()
# Call and check result
In [45]: methodB(a)
Out[45]: 666644445
In [46]: %timeit methodB(a)
724 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
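One plausible explanation for the gap (my reading, not something measured above): np.nonzero materializes an index array of intp integers (8 bytes each on typical 64-bit platforms) with one entry per non-zero element, whereas a != 0 produces a 1-byte-per-element boolean mask that sum() then reduces. A small NumPy sketch comparing the two temporaries on a smaller random array:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(0, 3, 1_000_000, dtype=np.uint8)

mask = a != 0           # boolean mask, 1 byte per element of a
idx = np.nonzero(a)[0]  # intp indices, one per non-zero element

print(mask.nbytes)      # 1000000
print(idx.nbytes, idx.dtype)

# Both approaches agree on the count
print(int(mask.sum()) == idx.size)  # True
```

With roughly two thirds of the elements non-zero, the index array is several times larger than the mask, and it also has to be filled element by element.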
You could also simply count the non-zero values with a loop:
import numba as nb
@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s
I know this seems wrong, but bear with me:
import numpy as np
import numba as nb
@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s
@nb.njit()
def count_len_nonzero(a):
    return len(np.nonzero(a)[0])
@nb.njit()
def count_sum_neq_zero(a):
    return (a != 0).sum()
np.random.seed(100)
a = np.random.randint(0, 3, 1000000000, dtype=np.uint8)
c = np.count_nonzero(a)
assert count_len_nonzero(a) == c
assert count_sum_neq_zero(a) == c
assert count_loop(a) == c
%timeit count_len_nonzero(a)
# 5.94 s ± 141 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_sum_neq_zero(a)
# 848 ms ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_loop(a)
# 189 ms ± 4.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
It is actually faster than np.count_nonzero, which, for some reason, can be quite slow:
%timeit np.count_nonzero(a)
# 4.36 s ± 69.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
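If you need it to be really fast for large arrays, you can even use Numba's prange to do the counting in parallel (for small arrays it will be slower because of the parallelization overhead).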
import numpy as np
from numba import njit, prange
@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_
Note that when you use Numba you usually want to write out your loops explicitly, because that is exactly what Numba is very good at optimizing.
I actually timed this against the other solutions mentioned here (using my Python module simple_benchmark):
Code to reproduce:
import numpy as np
from numba import njit, prange
@njit
def n_nonzero(a):
    return a[a != 0].size
@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])
@njit()
def methodB(a):
    return (a!=0).sum()
@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_
@njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s
from simple_benchmark import benchmark
args = {}
for exp in range(2, 20):
    size = 2**exp
    arr = np.random.random(size)
    arr[arr < 0.3] = 0.0
    args[size] = arr
b = benchmark(
funcs=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop),
arguments=args,
argument_name='array size',
warmups=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop)
)