计算数组快速方法中值的平均位置

Calculate the average position of a value in an array fast method

我有以下代码来计算 1s 在包含 1 和 0 的二维 numpy 数组中的平均位置。问题是它非常慢,我想知道是否有更快的方法?

row_sum = 0
col_sum = 0
ones_count = 0

for row_count, row in enumerate(array):
    for col_count, col in enumerate(row):
        if col == 1:
            row_sum += row_count
            col_sum += col_count
            ones_count += 1

average_position_ones = (row_sum / ones_count, col_sum / ones_count)

查看您的代码,您可以通过 np.sum() 获得数组的总和(前提是数组仅包含 0/1):

ones_count = array.sum()

print((arr.shape[0] - 1) / ones_count, (arr.shape[1] - 1) / ones_count)

这里有 3 种方法可以更快地计算 row_sumcol_sumones_count

基线

我使用这个数组进行测试

import numpy as np
import numba as nb

np.random.seed(1)

n = 10**4
array = np.random.randint(0,2,(n,n))

现在你的确切代码在我的机器上需要 20.3 s ± 397 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The Lazy One Liner Numpy 版本:

%timeit np.stack(np.where(array)).sum(axis=1),array.sum() 在我的机器上需要 1.13 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

这里 np.stack(np.where(array)).sum(axis=1) 就是你所说的 row_sumcol_sumarray.sum() 给你的 ones_count

避免循环投掷两次

您可以使用您的确切代码 numba.jit

@nb.njit
def test():
    row_sum = 0
    col_sum = 0
    ones_count = 0

    for row_count, row in enumerate(array):
        for col_count, col in enumerate(row):
            if col == 1:
                row_sum += row_count
                col_sum += col_count
                ones_count += 1

    return row_sum,col_sum,ones_count

%timeit test()

这个速度有点快。在我的机器上需要 50 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)。但绝对不值得付出努力。

多核版本

对您的代码稍加修改就可以 运行 多线程 numba

@nb.njit(parallel=True)
def test2():
    row_sum = 0
    col_sum = 0
    ones_count = 0
    
    for row_count in nb.prange(len(array)):
        row = array[row_count]
        for col_count, col in enumerate(row):
            if col == 1:
                row_sum += row_count
                col_sum += col_count
                ones_count += 1

    return row_sum,col_sum,ones_count

%timeit test2()

现在,与惰性 numpy 版本相比,这确实提供了一点速度。在我的 10 核机器上需要 13.3 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)。虽然它没有使用全部 10 个内核。

请注意,并行修改内容时必须小心。您可以创建竞争条件。而这里的情况并非如此,只是因为 numba 针对这种特定情况采取了反制措施。

进一步优化

正如 Jérôme Richard 在评论中指出的那样。可以通过使用 uint8 而不是默认的 int64 来优化最后一个版本。只需在数组上调用 .astype(np.uint8) 即可。然后在我的机器上需要 9.38 ms ± 935 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)