累积和> x时拆分数组的Numpy方法

Numpy way of splitting array when cumaltive sum > x

数据

让我们采用以下二维数组:

starts = [0, 4, 10, 13, 23, 27]
ends = [4, 10, 13, 23, 27, 32]
lengths = [4, 6, 3, 10, 4, 5] 

arr = np.array([starts, ends, lengths]).T

因此看起来像:

[[ 0  4  4]
 [ 4 10  6]
 [10 13  3]
 [13 23 10]
 [23 27  4]
 [27 32  5]]

目标

现在我想“循环”通过 lengths 并且一旦累积和达到 10 我想输出 startsends然后重新开始累计计数。


工作代码

tot_size = 0
start = 0

for i, numb in enumerate(arr[:,-1]):
    # update sum
    tot_size += numb

    # Check if target size is reached
    if tot_size >= 10:
        start_loc, end_loc = arr[:,0][start],  arr[:,1][i]
        print('Start: {}\nEnd: {}\nSum: {}\n'.format(start_loc, end_loc, tot_size))
        start = i + 1
        tot_size = 0

# Last part
start_loc, end_loc = arr[:,0][start],  arr[:,1][i]
print('Start: {}\nEnd: {}\nSum: {}\n'.format(start_loc, end_loc, tot_size))

将打印:

Start: 0 End: 10 Sum: 10

Start: 10 End: 23 Sum: 13

Start: 23 End: 32 Sum: 9

(我不需要知道结果总和,但我确实需要知道 startsends)


Numpy 尝试

我想肯定有一种更直接或矢量化的方法来使用 numpy 来完成此操作。

我在想像 np.remainder(np.cumsum(arr[:,-1]), 10) 这样的东西,但是很难说什么时候 接近 到目标数字(10 这里),这与 sum > x

时的拆分不同

因为上面的方法在 window 中不起作用,我想到了 stides 但这些 windows 是固定大小的

欢迎所有想法:)

Numpy 并不是为有效解决此类问题而设计的。您仍然可以使用一些技巧或 cumsum + 除法 + diff + where 或类似组合(如@Kevin 提议)的通常组合来解决此问题,但 AFAIK 它们都是低效的。实际上,它们需要许多临时数组昂贵的操作

临时数组很昂贵有两个原因:对于小数组,Numpy 函数的开销通常是每次调用几微秒,导致整个操作通常需要几十微秒;对于大数组,每个操作都将受内存限制,并且内存带宽在现代平台上很小。实际上,它甚至更糟糕,因为由于页面错误,写入新分配的数组要慢得多,而且 Numpy 数组写入目前在大多数平台上都没有优化(包括主流的 x86-64 平台)。

至于“昂贵的操作”,这包括在 O(n log n) 中运行的排序(默认使用 quick-sort)并且通常受内存限制,找到唯一值(目前在内部进行排序) 和整数除法,这是众所周知的非常慢的。


解决此问题的一种方法是使用 Numba(或 Cython)。 Numba 使用 just-in-time 编译器来编写快速优化的函数。编写自己的高效基本 Numpy built-ins 特别有用。这是一个基于您的代码的示例:

import numba as nb
import numpy as np

@nb.njit(['(int32[:,:],)', '(int64[:,:],)'])
def compute(arr):
    n = len(arr)
    tot_size, start, cur = 0, 0, 0
    slices = np.empty((n, 2), arr.dtype)

    for i in range(n):
        tot_size += arr[i, 2]

        if tot_size >= 10:
            slices[cur, 0] = arr[start, 0]
            slices[cur, 1] = arr[i, 1]
            start = i + 1
            cur += 1
            tot_size = 0

    slices[cur, 0] = arr[start, 0]
    slices[cur, 1] = arr[i, 1]
    return slices[:cur+1]

对于您的小示例,Numba 函数在我的机器上大约需要 0.75 us,而初始解决方案需要 3.5 us。相比之下,@Kevin 提供的 Numpy 解决方案(返回索引)需要 24 us 用于 np.unique 和 6 us 用于 division-based 解决方案。事实上,基本的 np.cumsum 在我的机器上已经占用了 0.65 us。因此,Numba 解决方案是最快的。对于较大的阵列尤其如此。