使用 numpy.where 创建掩码

Mask creation using numpy.where

我有一个创建掩码(布尔数组)的函数,希望能让它运行得更快。

def get_validity_1(ts, times):
    """Build a 0/1 validity mask marking the samples inside any time window.

    Args:
        ts: 1-D sequence of sample times. The window semantics below hold
            when it is sorted in ascending order.
        times: Iterable of ``(start, end)`` pairs; each pair marks the
            half-open window ``start <= t < end``.

    Returns:
        A float ``numpy.ndarray`` of length ``len(ts)`` holding 1.0 for
        samples inside at least one window and 0.0 elsewhere.
    """
    ts = numpy.asarray(ts)
    validity = numpy.zeros(len(ts))
    if len(ts) == 0:
        # argmax raises on an empty array; there is nothing to mark anyway.
        return validity

    def first_index_at_or_after(threshold):
        reached = ts >= threshold
        candidate = numpy.argmax(reached)
        # argmax yields 0 both when the first sample qualifies and when no
        # sample does; in the latter case the boundary lies past the end.
        return candidate if reached[candidate] else len(ts)

    for start, end in times:
        validity[first_index_at_or_after(start):first_index_at_or_after(end)] = 1
    return validity
# Example call: 100,000,000 evenly spaced samples on [0, 1] and two (start, end) windows.
res_1 = get_validity_1(numpy.linspace(0, 1, 100000000), numpy.array([[0.01, 0.1], [0.5, 0.8]]))

我想知道如何用 numpy.where 配合条件表达式来实现同样的功能。我试过下面的写法:

def get_validity_2(ts, times):
    # Attempted one-liner; it does not work as written:
    # - ts.all() reduces the whole array to a single boolean, so each
    #   t1<ts.all()<t2 is one scalar comparison, not an element-wise test.
    # - numpy.logical_or is a binary ufunc but receives a single argument
    #   (the list), which raises "ValueError: invalid number of arguments".
    return numpy.where(numpy.logical_or([t1<ts.all()<t2 for t1, t2 in times]))

但 Python 抛出了异常:

ValueError: invalid number of arguments

结果需要满足一些断言(见下方脚本末尾)。

下面是一个可以直接运行的完整脚本:

import time, numpy

def get_validity_1(ts, times):
    """Return a float mask that is 1 where a sample falls in a time window.

    Args:
        ts: 1-D sequence of sample times. The window semantics below hold
            when it is sorted in ascending order.
        times: Iterable of ``(start, end)`` pairs, each describing the
            half-open window ``start <= t < end``.

    Returns:
        ``numpy.ndarray`` of floats with ``len(ts)`` entries: 1.0 inside any
        window, 0.0 outside.
    """
    ts = numpy.asarray(ts)
    n = len(ts)
    validity = numpy.zeros(n)
    if n == 0:
        # argmax cannot be evaluated on an empty array.
        return validity
    for start, end in times:
        bounds = []
        for threshold in (start, end):
            reached = ts >= threshold
            first = numpy.argmax(reached)
            # argmax also returns 0 when nothing reaches the threshold, so
            # confirm the hit; otherwise the boundary is past the last sample.
            bounds.append(first if reached[first] else n)
        validity[bounds[0]:bounds[1]] = 1
    return validity
    
def get_validity_2(ts, times):
    # Placeholder attempt that still fails: ts.all() collapses the array to
    # one boolean before the comparison, and numpy.logical_or (a two-input
    # ufunc) is called with just one argument, so it raises ValueError.
    return numpy.where(numpy.logical_or([t1<ts.all()<t2 for t1, t2 in times]))

if __name__ == "__main__":
    # Benchmark the reference implementation against get_validity_2 and
    # check that both produce the same mask.
    n = 100000000
    ts = numpy.linspace(0, 1, n)

    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])

    # perf_counter is monotonic and high-resolution, which makes it a better
    # fit for timing than the wall clock returned by time.time().
    t0 = time.perf_counter()
    res_1 = get_validity_1(ts, times)
    t_1 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_2 = get_validity_2(ts, times)
    t_2 = time.perf_counter() - t0

    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))

    # `res_1 == res_2` yields an array whose truth value is ambiguous inside
    # an assert; array_equal collapses the comparison to a single bool.
    assert numpy.array_equal(res_1, res_2)
    assert t_1 > t_2

有谁知道如何完成函数 'get_validity_2' 并通过断言?或者有没有现成的库函数可以解决这个问题?

# Element-wise OR of one half-open [t1, t2) test per interval. ufunc.reduce
# accepts any number of intervals, whereas unpacking them as positional
# arguments only works for exactly two (a third would be taken as `out`).
numpy.logical_or.reduce([numpy.logical_and(t1 <= ts, ts < t2) for t1, t2 in times])

如果你想要的是像你尝试的那种单行写法,可以用上面这一行。但它的效率仍然不高,因为每个区间都要对整个大数组做 O(N) 的比较。

由于 ts 是有序的,可以用二分查找在 O(log(N)) 时间内更快地找到 start/end 对应的索引:

def get_validity_3(ts, times):
    """Build the 0/1 validity mask using binary search on sorted samples.

    Args:
        ts: 1-D sequence of sample times; must be sorted in ascending order
            for the binary search to be valid.
        times: Iterable of ``(start, end)`` pairs, each describing the
            half-open window ``start <= t < end``.

    Returns:
        A float ``numpy.ndarray`` of length ``len(ts)`` with 1.0 inside any
        window and 0.0 elsewhere.
    """
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # The default side="left" gives the first index whose sample is
        # >= the value, or len(ts) when every sample is smaller.
        lo = numpy.searchsorted(ts, start)
        hi = numpy.searchsorted(ts, end)
        validity[lo:hi] = 1
    return validity

完整代码:

import time, numpy

def get_validity_1(ts, times):
    """Reference implementation: mark samples that lie inside any window.

    Args:
        ts: 1-D sequence of sample times. The window semantics below hold
            when it is sorted in ascending order.
        times: Iterable of ``(start, end)`` pairs, each describing the
            half-open window ``start <= t < end``.

    Returns:
        A float ``numpy.ndarray`` of length ``len(ts)``: 1.0 for samples in
        at least one window, 0.0 otherwise.
    """
    ts = numpy.asarray(ts)
    size = len(ts)
    validity = numpy.zeros(size)
    if size == 0:
        # argmax raises on an empty array; return the empty mask directly.
        return validity
    for start, end in times:
        reached = ts >= start
        index_start = numpy.argmax(reached)
        if not reached[index_start]:
            # argmax returned its "nothing found" 0: the window begins
            # after the last sample, so it covers nothing.
            continue
        reached = ts >= end
        index_end = numpy.argmax(reached)
        if not reached[index_end]:
            # No sample reaches `end`: the window runs to the last sample.
            index_end = size
        validity[index_start:index_end] = 1
    return validity
    
def get_validity_2(ts, times):
    """Return a boolean mask of the samples inside any time window.

    Each window is evaluated with a full element-wise comparison, so the
    cost is O(len(ts)) per window and ``ts`` does not need to be sorted.

    Args:
        ts: 1-D sequence of sample times.
        times: Iterable of ``(t1, t2)`` pairs, each describing the half-open
            window ``t1 <= t < t2`` (the same bounds as ``get_validity_1``).

    Returns:
        A boolean ``numpy.ndarray`` shaped like ``ts``; True where a sample
        falls in at least one window.
    """
    ts = numpy.asarray(ts)
    mask = numpy.zeros(ts.shape, dtype=bool)
    # Accumulating with |= handles any number of windows, including none;
    # unpacking them into logical_or(*...) only works for exactly two.
    for t1, t2 in times:
        mask |= (ts >= t1) & (ts < t2)
    return mask
    
def get_validity_3(ts, times):
    """Mark samples inside any window, locating bounds by binary search.

    Args:
        ts: 1-D sequence of sample times; must be sorted in ascending order
            for ``numpy.searchsorted`` to return meaningful positions.
        times: Iterable of ``(start, end)`` pairs, each describing the
            half-open window ``start <= t < end``.

    Returns:
        A float ``numpy.ndarray`` of length ``len(ts)``: 1.0 for samples in
        at least one window, 0.0 otherwise.
    """
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # One call resolves both bounds; with the default side="left" each
        # result is the first index whose sample is >= the bound.
        first, stop = numpy.searchsorted(ts, [start, end])
        validity[first:stop] = 1
    return validity
    

if __name__ == "__main__":
    # Time the three implementations on the same input and verify that they
    # all produce the same mask.
    n = 100000000
    ts = numpy.linspace(0, 1, n)

    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])

    # perf_counter is monotonic and high-resolution, so the measured
    # intervals are not disturbed by wall-clock adjustments.
    t0 = time.perf_counter()
    res_1 = get_validity_1(ts, times)
    t_1 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_2 = get_validity_2(ts, times)
    t_2 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_3 = get_validity_3(ts, times)
    t_3 = time.perf_counter() - t0

    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))
    print("t_3: " + str(t_3))

    # array_equal also reports a shape mismatch as False instead of
    # broadcasting or raising.
    assert numpy.array_equal(res_1, res_2)
    assert numpy.array_equal(res_1, res_3)

输出:

t_1: 0.4412200450897217
t_2: 0.3446168899536133
t_3: 0.14597129821777344