计算数组快速方法中值的平均位置
Calculate the average position of a value in an array fast method
我有以下代码来计算 1s 在包含 1 和 0 的二维 numpy 数组中的平均位置。问题是它非常慢,我想知道是否有更快的方法?
row_sum = 0
col_sum = 0
ones_count = 0
for row_count, row in enumerate(array):
for col_count, col in enumerate(row):
if col == 1:
row_sum += row_count
col_sum += col_count
ones_count += 1
average_position_ones = (row_sum / ones_count, col_sum / ones_count)
查看您的代码,您可以通过 np.sum()
获得数组的总和(前提是数组仅包含 0/1):
ones_count = array.sum()
print((arr.shape[0] - 1) / ones_count, (arr.shape[1] - 1) / ones_count)
这里有 3 种方法可以更快地计算 row_sum
、col_sum
和 ones_count
。
基线
我使用这个数组进行测试
import numpy as np
import numba as nb
np.random.seed(1)
n = 10**4
array = np.random.randint(0,2,(n,n))
现在你的确切代码在我的机器上需要 20.3 s ± 397 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
。
The Lazy One Liner Numpy 版本:
%timeit np.stack(np.where(array)).sum(axis=1),array.sum()
在我的机器上需要 1.13 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
。
这里 np.stack(np.where(array)).sum(axis=1)
就是你所说的 row_sum
和 col_sum
和 array.sum()
给你的 ones_count
避免循环投掷两次
您可以使用您的确切代码 numba.jit
@nb.njit
def test():
row_sum = 0
col_sum = 0
ones_count = 0
for row_count, row in enumerate(array):
for col_count, col in enumerate(row):
if col == 1:
row_sum += row_count
col_sum += col_count
ones_count += 1
return row_sum,col_sum,ones_count
%timeit test()
这个速度有点快。在我的机器上需要 50 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
。但绝对不值得付出努力。
多核版本
对您的代码稍加修改就可以 运行 多线程 numba
@nb.njit(parallel=True)
def test2():
row_sum = 0
col_sum = 0
ones_count = 0
for row_count in nb.prange(len(array)):
row = array[row_count]
for col_count, col in enumerate(row):
if col == 1:
row_sum += row_count
col_sum += col_count
ones_count += 1
return row_sum,col_sum,ones_count
%timeit test2()
现在,与惰性 numpy
版本相比,这确实提供了一点速度。在我的 10 核机器上需要 13.3 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
。虽然它没有使用全部 10 个内核。
请注意,并行修改内容时必须小心。您可以创建竞争条件。而这里的情况并非如此,只是因为 numba
针对这种特定情况采取了反制措施。
进一步优化
正如 Jérôme Richard 在评论中指出的那样。可以通过使用 uint8 而不是默认的 int64 来优化最后一个版本。只需在数组上调用 .astype(np.uint8)
即可。然后在我的机器上需要 9.38 ms ± 935 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
。
我有以下代码来计算 1s 在包含 1 和 0 的二维 numpy 数组中的平均位置。问题是它非常慢,我想知道是否有更快的方法?
row_sum = 0
col_sum = 0
ones_count = 0
for row_count, row in enumerate(array):
for col_count, col in enumerate(row):
if col == 1:
row_sum += row_count
col_sum += col_count
ones_count += 1
average_position_ones = (row_sum / ones_count, col_sum / ones_count)
查看您的代码,您可以通过 np.sum()
获得数组的总和(前提是数组仅包含 0/1):
ones_count = array.sum()
print((arr.shape[0] - 1) / ones_count, (arr.shape[1] - 1) / ones_count)
这里有 3 种方法可以更快地计算 row_sum
、col_sum
和 ones_count
。
基线
我使用这个数组进行测试
import numpy as np
import numba as nb
np.random.seed(1)
n = 10**4
array = np.random.randint(0,2,(n,n))
现在你的确切代码在我的机器上需要 20.3 s ± 397 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
。
The Lazy One Liner Numpy 版本:
%timeit np.stack(np.where(array)).sum(axis=1),array.sum()
在我的机器上需要 1.13 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
。
这里 np.stack(np.where(array)).sum(axis=1)
就是你所说的 row_sum
和 col_sum
和 array.sum()
给你的 ones_count
避免循环投掷两次
您可以使用您的确切代码 numba.jit
@nb.njit
def test():
row_sum = 0
col_sum = 0
ones_count = 0
for row_count, row in enumerate(array):
for col_count, col in enumerate(row):
if col == 1:
row_sum += row_count
col_sum += col_count
ones_count += 1
return row_sum,col_sum,ones_count
%timeit test()
这个速度有点快。在我的机器上需要 50 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
。但绝对不值得付出努力。
多核版本
对您的代码稍加修改就可以 运行 多线程 numba
@nb.njit(parallel=True)
def test2():
row_sum = 0
col_sum = 0
ones_count = 0
for row_count in nb.prange(len(array)):
row = array[row_count]
for col_count, col in enumerate(row):
if col == 1:
row_sum += row_count
col_sum += col_count
ones_count += 1
return row_sum,col_sum,ones_count
%timeit test2()
现在,与惰性 numpy
版本相比,这确实提供了一点速度。在我的 10 核机器上需要 13.3 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
。虽然它没有使用全部 10 个内核。
请注意,并行修改内容时必须小心。您可以创建竞争条件。而这里的情况并非如此,只是因为 numba
针对这种特定情况采取了反制措施。
进一步优化
正如 Jérôme Richard 在评论中指出的那样。可以通过使用 uint8 而不是默认的 int64 来优化最后一个版本。只需在数组上调用 .astype(np.uint8)
即可。然后在我的机器上需要 9.38 ms ± 935 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
。