有没有更好的方法来删除长度等于或高于阈值的连续零部分?

Is there a better way to remove parts with consecutive zeros that have length equal or above threshold?

问题陈述:

如标题所述,我想从具有 consecutive zeros 的一维数组中删除部分长度等于或高于一个阈值.


我的解决方案:

我生成了如下 MRE 所示的解决方案:

import numpy as np

THRESHOLD = 4

a = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1))

print("Input: " + str(a))

# Find the indices of the parts that meet threshold requirement
gaps_above_threshold_inds = np.where(np.diff(np.nonzero(a)[0]) - 1 >= THRESHOLD)[0]

# Delete these parts from array
for idx in gaps_above_threshold_inds:
    a = np.delete(a, list(range(np.nonzero(a)[0][idx] + 1, np.nonzero(a)[0][idx + 1])))
    
print("Output: " + str(a))

输出:

Input:  [1 1 0 1 0 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1]
Output: [1 1 0 1 1 1 0 0 0 1 1]

问题:

是否有不那么复杂并且更有效的方法来在 numpy 数组上执行此操作?


编辑:

根据@mozway 的评论,我正在编辑我的问题以提供更多信息。

基本上,问题域是:

我的目标 是删除超过长度阈值的零部分,正如我已经说过的那样。

更详细的问题:



答案评论:

关于我对 高效 numpy 处理 的第一个担忧,@mathfux 的解决方案非常棒,基本上就是我一直在寻找的。这就是我接受这个的原因。

然而,@Jérôme Richard 的方法回答了我的第二个问题,它提供了一个非常高性能的解决方案;如果数据集非常大,真的很有用。

感谢您的精彩解答!

一个非常有效的方法是使用itertools.groupby+itertools.chain:

from itertools import groupby, chain
a2 = np.array(list(chain(*(l for k,g in groupby(a)
                           if len(l:=list(g))<THRESHOLD or k))))

输出:

array([1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1])

这工作相对较快,例如在 100 万个项目上:

# A = np.random.randint(2, size=1000000)
%%timeit
np.array(list(chain(*(l for k,g in groupby(a)
                      if len(l:=list(g))<THRESHOLD or k))))

# 254 ms ± 3.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

np.delete 每次调用都创建一个新数组,效率很低。更快的解决方案是将所有要保存的值存储在 mask/boolean 数组中,然后立即过滤输入数组。但是,如果只使用 Numpy,这仍然可能需要一个纯 Python 循环。一个更简单和更快的解决方案是使用 Numba(或 Cython)来做到这一点。这是一个实现:

import numpy as np
import numba as nb

@nb.njit('int_[:](int_[:], int_)')
def filterZeros(arr, threshold):
    n = len(arr)
    res = np.empty(n, dtype=arr.dtype)
    count = 0
    j = 0
    for i in range(n):
        if arr[i] == 0:
            count += 1
        else:
            if count >= threshold:
                j -= count
            count = 0
        res[j] = arr[i]
        j += 1
    if n > 0 and arr[n-1] == 0 and count >= threshold:
        j -= count
    return res[0:j]

a = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1))
a = filterZeros(a, 4)
print("Output: " + str(a))

这是我机器上包含 100_000 项的随机二进制数组的结果:

Reference implementation: 5982    ms
Mozway's solution:          23.4  ms
This implementation:         0.11 ms

因此,该解决方案比初始解决方案快 54381,比 Mozway 快 212 倍。通过 in-place(销毁输入数组)并告诉 Numba 数组在内存中 contiguous,代码甚至可以快 30% (使用 ::1 而不是 :)。

也可以找到非零项的差异,修复超过阈值的项并以正确的方式重建序列。

def numpy_fix(a):
    # STEP 1. find indices of nonzero items: [0 1 3 8 9 13 19]
    idx = np.flatnonzero(a)
    # STEP 2. Find differences along these indices (also insert a leading zero): [0 1 2 5 1 4 6] 
    df = np.diff(idx, prepend=0)
    # STEP 3. Fix differences of indices larger than THRESHOLD: [0 1 2 1 1 4 1] 
    df[df>THRESHOLD] = 1
    # STEP 4. Given differences on indices, reconstruct indices themselves: [0 1 3 4 5 9 10]
    cs = np.cumsum(df)
    z = np.zeros(cs[-1]+1, dtype=int) # create a list of zeros
    z[cs] = 1 #pad it with ones within indices found
    return z
>>> numpy_fix(a)
array([1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1])

(请注意,只有当 a 没有前导或尾随零时才是正确的)

%timeit numpy_fix(np.tile(a, (1, 50000)))
39.3 ms ± 865 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)