使用 numpy.where 创建掩码
Mask creation using numpy.where
我有一个用于创建掩码(布尔数组)的函数,希望能让它运行得更快。
def get_validity_1(ts, times):
    """Return a 0/1 mask over ``ts`` marking samples inside any ``[start, end)``.

    Each interval bound is located with a full linear scan
    (``numpy.argmax`` over a boolean comparison), so the cost grows with
    ``len(ts)`` for every bound.

    Args:
        ts: 1-D sequence of sample times, strictly increasing.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 for samples with
        ``start <= ts < end`` for some interval, and 0 elsewhere.
    """
    ts = numpy.asarray(ts)
    size = len(ts)
    validity = numpy.zeros(size)
    if size == 0:
        # numpy.argmax raises on an empty array; there is nothing to mark.
        return validity

    def first_index_reaching(bound):
        reached = ts >= bound
        index = numpy.argmax(reached)
        # argmax of an all-False array is 0, which would wrongly point at the
        # first sample; a bound past the last sample maps to len(ts) instead.
        return index if reached[index] else size

    for start, end in times:
        validity[first_index_reaching(start):first_index_reaching(end)] = 1
    return validity
# Example call: 100 million evenly spaced samples on [0, 1] and two intervals.
res_1 = get_validity_1(
    numpy.linspace(0, 1, 100000000),
    numpy.array([[0.01, 0.1], [0.5, 0.8]]),
)
我的问题是:如何用 numpy.where 条件来实现同样的功能?我尝试过这样写:
def get_validity_2(ts, times):
    """Attempted numpy.where-based mask builder; this version does not work.

    NOTE(review): kept as the failing attempt that the surrounding text
    discusses; only the lost indentation is restored.

    Args:
        ts: NumPy array of sample times.
        times: Iterable of ``(start, end)`` pairs.

    Raises:
        ValueError or TypeError (depending on the NumPy version), because
        ``numpy.logical_or`` is called with a single argument.
    """
    # ts.all() collapses the whole array to one bool, so every comparison in
    # the list is a scalar rather than an element-wise mask. The list is then
    # passed as the only argument to the binary ufunc logical_or, which raises.
    return numpy.where(numpy.logical_or([t1 < ts.all() < t2 for t1, t2 in times]))
但 Python 抛出了异常:
ValueError: invalid number of arguments
以下是对输入的一些假设:
- ts[n-1] < ts[n]
- times[n][0] < times[n][1]
- times[n-1][1] < times[n][0]
下面是一个可以直接运行的测试脚本:
import time, numpy
def get_validity_1(ts, times):
    """Return a 0/1 mask over ``ts`` marking samples inside any ``[start, end)``.

    Each interval bound is located with a full linear scan
    (``numpy.argmax`` over a boolean comparison), so the cost grows with
    ``len(ts)`` for every bound.

    Args:
        ts: 1-D sequence of sample times, strictly increasing.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 for samples with
        ``start <= ts < end`` for some interval, and 0 elsewhere.
    """
    ts = numpy.asarray(ts)
    size = len(ts)
    validity = numpy.zeros(size)
    if size == 0:
        # numpy.argmax raises on an empty array; there is nothing to mark.
        return validity

    def first_index_reaching(bound):
        reached = ts >= bound
        index = numpy.argmax(reached)
        # argmax of an all-False array is 0, which would wrongly point at the
        # first sample; a bound past the last sample maps to len(ts) instead.
        return index if reached[index] else size

    for start, end in times:
        validity[first_index_reaching(start):first_index_reaching(end)] = 1
    return validity
def get_validity_2(ts, times):
    """Attempted numpy.where-based mask builder; this version does not work.

    NOTE(review): kept as the failing attempt that the surrounding text
    discusses; only the lost indentation is restored.

    Args:
        ts: NumPy array of sample times.
        times: Iterable of ``(start, end)`` pairs.

    Raises:
        ValueError or TypeError (depending on the NumPy version), because
        ``numpy.logical_or`` is called with a single argument.
    """
    # ts.all() collapses the whole array to one bool, so every comparison in
    # the list is a scalar rather than an element-wise mask. The list is then
    # passed as the only argument to the binary ufunc logical_or, which raises.
    return numpy.where(numpy.logical_or([t1 < ts.all() < t2 for t1, t2 in times]))
if __name__ == "__main__":
    # Time the reference implementation against the candidate on a large,
    # evenly spaced time axis, then check that both produce the same mask.
    n = 100000000
    ts = numpy.linspace(0, 1, n)
    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])

    # perf_counter is the monotonic, high-resolution clock meant for timing.
    t0 = time.perf_counter()
    res_1 = get_validity_1(ts, times)
    t_1 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_2 = get_validity_2(ts, times)
    t_2 = time.perf_counter() - t0

    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))

    # `res_1 == res_2` is an element-wise array whose truth value is
    # ambiguous; reduce it with .all() before asserting.
    assert (res_1 == res_2).all()
    # Goal of the exercise: the candidate must be faster than the reference.
    assert t_1 > t_2
有谁知道如何完成函数 'get_validity_2' 并让断言通过?
或者有没有现成的库函数可以解决这个问题?
# ufunc.reduce combines any number of interval masks; unpacking them with *
# only works for exactly two intervals. `t1 <= ts` keeps the start inclusive,
# matching the [start, end) slices produced by get_validity_1.
numpy.logical_or.reduce([numpy.logical_and(t1 <= ts, ts < t2) for t1, t2 in times])
如果你想要的是你所尝试的那种单行写法,可以使用上面这一行。但它的效率仍然不高,因为需要对大数组做 O(N) 的逐元素比较。
由于 ts 是已排序的,下面的方法使用二分查找,以 O(log(N)) 的时间更快地找到 start/end 索引:
def get_validity_3(ts, times):
    """Return a 0/1 mask over sorted ``ts`` using binary search for the bounds.

    ``numpy.searchsorted`` (default ``side='left'``) returns the first index
    whose value is >= the searched bound, so each interval marks the samples
    with ``start <= ts < end``. A bound past the last sample maps to
    ``len(ts)``.

    Args:
        ts: 1-D sequence of sample times, sorted in ascending order.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 inside the intervals
        and 0 elsewhere.
    """
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # The module is imported as `numpy` (there is no `np` alias).
        index_start = numpy.searchsorted(ts, start)
        index_end = numpy.searchsorted(ts, end)
        validity[index_start:index_end] = 1
    return validity
整体代码:
import time, numpy
def get_validity_1(ts, times):
    """Return a 0/1 mask over ``ts`` marking samples inside any ``[start, end)``.

    Each interval bound is located with a full linear scan
    (``numpy.argmax`` over a boolean comparison), so the cost grows with
    ``len(ts)`` for every bound.

    Args:
        ts: 1-D sequence of sample times, strictly increasing.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 for samples with
        ``start <= ts < end`` for some interval, and 0 elsewhere.
    """
    ts = numpy.asarray(ts)
    size = len(ts)
    validity = numpy.zeros(size)
    if size == 0:
        # numpy.argmax raises on an empty array; there is nothing to mark.
        return validity

    def first_index_reaching(bound):
        reached = ts >= bound
        index = numpy.argmax(reached)
        # argmax of an all-False array is 0, which would wrongly point at the
        # first sample; a bound past the last sample maps to len(ts) instead.
        return index if reached[index] else size

    for start, end in times:
        validity[first_index_reaching(start):first_index_reaching(end)] = 1
    return validity
def get_validity_2(ts, times):
    """Return a boolean mask over ``ts`` built from element-wise comparisons.

    Every interval costs one full pass over ``ts``.

    Args:
        ts: 1-D sequence of sample times.
        times: Iterable of ``(start, end)`` pairs; any number of intervals,
            including none.

    Returns:
        A boolean array shaped like ``ts`` that is True where
        ``start <= ts < end`` for at least one interval.
    """
    ts = numpy.asarray(ts)
    validity = numpy.zeros(ts.shape, dtype=bool)
    for start, end in times:
        # Accumulating with |= works for any interval count; unpacking the
        # masks into logical_or(*masks) only works for exactly two.
        # The start bound is inclusive to match get_validity_1's [start, end).
        validity |= (start <= ts) & (ts < end)
    return validity
def get_validity_3(ts, times):
    """Return a 0/1 mask over sorted ``ts`` using binary search for the bounds.

    ``numpy.searchsorted`` (default ``side='left'``) returns the first index
    whose value is >= the searched bound, so each interval marks the samples
    with ``start <= ts < end``. A bound past the last sample maps to
    ``len(ts)``.

    Args:
        ts: 1-D sequence of sample times, sorted in ascending order.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 inside the intervals
        and 0 elsewhere.
    """
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # The module is imported as `numpy` (there is no `np` alias).
        index_start = numpy.searchsorted(ts, start)
        index_end = numpy.searchsorted(ts, end)
        validity[index_start:index_end] = 1
    return validity
if __name__ == "__main__":
    # Time the three implementations on the same large, evenly spaced time
    # axis and check that they all produce the same mask.
    n = 100000000
    ts = numpy.linspace(0, 1, n)
    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])

    # perf_counter is the monotonic, high-resolution clock meant for timing.
    t0 = time.perf_counter()
    res_1 = get_validity_1(ts, times)
    t_1 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_2 = get_validity_2(ts, times)
    t_2 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_3 = get_validity_3(ts, times)
    t_3 = time.perf_counter() - t0

    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))
    print("t_3: " + str(t_3))

    # Element-wise comparison reduced with .all(): the masks must match
    # at every sample.
    assert (res_1 == res_2).all()
    assert (res_1 == res_3).all()
输出:
t_1: 0.4412200450897217
t_2: 0.3446168899536133
t_3: 0.14597129821777344
我有一个用于创建掩码(布尔数组)的函数,希望能让它运行得更快。
def get_validity_1(ts, times):
    """Return a 0/1 mask over ``ts`` marking samples inside any ``[start, end)``.

    Each interval bound is located with a full linear scan
    (``numpy.argmax`` over a boolean comparison), so the cost grows with
    ``len(ts)`` for every bound.

    Args:
        ts: 1-D sequence of sample times, strictly increasing.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 for samples with
        ``start <= ts < end`` for some interval, and 0 elsewhere.
    """
    ts = numpy.asarray(ts)
    size = len(ts)
    validity = numpy.zeros(size)
    if size == 0:
        # numpy.argmax raises on an empty array; there is nothing to mark.
        return validity

    def first_index_reaching(bound):
        reached = ts >= bound
        index = numpy.argmax(reached)
        # argmax of an all-False array is 0, which would wrongly point at the
        # first sample; a bound past the last sample maps to len(ts) instead.
        return index if reached[index] else size

    for start, end in times:
        validity[first_index_reaching(start):first_index_reaching(end)] = 1
    return validity
# Example call: 100 million evenly spaced samples on [0, 1] and two intervals.
res_1 = get_validity_1(
    numpy.linspace(0, 1, 100000000),
    numpy.array([[0.01, 0.1], [0.5, 0.8]]),
)
我的问题是:如何用 numpy.where 条件来实现同样的功能?我尝试过这样写:
def get_validity_2(ts, times):
    """Attempted numpy.where-based mask builder; this version does not work.

    NOTE(review): kept as the failing attempt that the surrounding text
    discusses; only the lost indentation is restored.

    Args:
        ts: NumPy array of sample times.
        times: Iterable of ``(start, end)`` pairs.

    Raises:
        ValueError or TypeError (depending on the NumPy version), because
        ``numpy.logical_or`` is called with a single argument.
    """
    # ts.all() collapses the whole array to one bool, so every comparison in
    # the list is a scalar rather than an element-wise mask. The list is then
    # passed as the only argument to the binary ufunc logical_or, which raises.
    return numpy.where(numpy.logical_or([t1 < ts.all() < t2 for t1, t2 in times]))
但 Python 抛出了异常:
ValueError: invalid number of arguments
以下是对输入的一些假设:
- ts[n-1] < ts[n]
- times[n][0] < times[n][1]
- times[n-1][1] < times[n][0]
下面是一个可以直接运行的测试脚本:
import time, numpy
def get_validity_1(ts, times):
    """Return a 0/1 mask over ``ts`` marking samples inside any ``[start, end)``.

    Each interval bound is located with a full linear scan
    (``numpy.argmax`` over a boolean comparison), so the cost grows with
    ``len(ts)`` for every bound.

    Args:
        ts: 1-D sequence of sample times, strictly increasing.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 for samples with
        ``start <= ts < end`` for some interval, and 0 elsewhere.
    """
    ts = numpy.asarray(ts)
    size = len(ts)
    validity = numpy.zeros(size)
    if size == 0:
        # numpy.argmax raises on an empty array; there is nothing to mark.
        return validity

    def first_index_reaching(bound):
        reached = ts >= bound
        index = numpy.argmax(reached)
        # argmax of an all-False array is 0, which would wrongly point at the
        # first sample; a bound past the last sample maps to len(ts) instead.
        return index if reached[index] else size

    for start, end in times:
        validity[first_index_reaching(start):first_index_reaching(end)] = 1
    return validity
def get_validity_2(ts, times):
    """Attempted numpy.where-based mask builder; this version does not work.

    NOTE(review): kept as the failing attempt that the surrounding text
    discusses; only the lost indentation is restored.

    Args:
        ts: NumPy array of sample times.
        times: Iterable of ``(start, end)`` pairs.

    Raises:
        ValueError or TypeError (depending on the NumPy version), because
        ``numpy.logical_or`` is called with a single argument.
    """
    # ts.all() collapses the whole array to one bool, so every comparison in
    # the list is a scalar rather than an element-wise mask. The list is then
    # passed as the only argument to the binary ufunc logical_or, which raises.
    return numpy.where(numpy.logical_or([t1 < ts.all() < t2 for t1, t2 in times]))
if __name__ == "__main__":
    # Time the reference implementation against the candidate on a large,
    # evenly spaced time axis, then check that both produce the same mask.
    n = 100000000
    ts = numpy.linspace(0, 1, n)
    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])

    # perf_counter is the monotonic, high-resolution clock meant for timing.
    t0 = time.perf_counter()
    res_1 = get_validity_1(ts, times)
    t_1 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_2 = get_validity_2(ts, times)
    t_2 = time.perf_counter() - t0

    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))

    # `res_1 == res_2` is an element-wise array whose truth value is
    # ambiguous; reduce it with .all() before asserting.
    assert (res_1 == res_2).all()
    # Goal of the exercise: the candidate must be faster than the reference.
    assert t_1 > t_2
有谁知道如何完成函数 'get_validity_2' 并让断言通过? 或者有没有现成的库函数可以解决这个问题?
# ufunc.reduce combines any number of interval masks; unpacking them with *
# only works for exactly two intervals. `t1 <= ts` keeps the start inclusive,
# matching the [start, end) slices produced by get_validity_1.
numpy.logical_or.reduce([numpy.logical_and(t1 <= ts, ts < t2) for t1, t2 in times])
如果你想要的是你所尝试的那种单行写法,可以使用上面这一行。但它的效率仍然不高,因为需要对大数组做 O(N) 的逐元素比较。
由于 ts 是已排序的,下面的方法使用二分查找,以 O(log(N)) 的时间更快地找到 start/end 索引:
def get_validity_3(ts, times):
    """Return a 0/1 mask over sorted ``ts`` using binary search for the bounds.

    ``numpy.searchsorted`` (default ``side='left'``) returns the first index
    whose value is >= the searched bound, so each interval marks the samples
    with ``start <= ts < end``. A bound past the last sample maps to
    ``len(ts)``.

    Args:
        ts: 1-D sequence of sample times, sorted in ascending order.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 inside the intervals
        and 0 elsewhere.
    """
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # The module is imported as `numpy` (there is no `np` alias).
        index_start = numpy.searchsorted(ts, start)
        index_end = numpy.searchsorted(ts, end)
        validity[index_start:index_end] = 1
    return validity
整体代码:
import time, numpy
def get_validity_1(ts, times):
    """Return a 0/1 mask over ``ts`` marking samples inside any ``[start, end)``.

    Each interval bound is located with a full linear scan
    (``numpy.argmax`` over a boolean comparison), so the cost grows with
    ``len(ts)`` for every bound.

    Args:
        ts: 1-D sequence of sample times, strictly increasing.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 for samples with
        ``start <= ts < end`` for some interval, and 0 elsewhere.
    """
    ts = numpy.asarray(ts)
    size = len(ts)
    validity = numpy.zeros(size)
    if size == 0:
        # numpy.argmax raises on an empty array; there is nothing to mark.
        return validity

    def first_index_reaching(bound):
        reached = ts >= bound
        index = numpy.argmax(reached)
        # argmax of an all-False array is 0, which would wrongly point at the
        # first sample; a bound past the last sample maps to len(ts) instead.
        return index if reached[index] else size

    for start, end in times:
        validity[first_index_reaching(start):first_index_reaching(end)] = 1
    return validity
def get_validity_2(ts, times):
    """Return a boolean mask over ``ts`` built from element-wise comparisons.

    Every interval costs one full pass over ``ts``.

    Args:
        ts: 1-D sequence of sample times.
        times: Iterable of ``(start, end)`` pairs; any number of intervals,
            including none.

    Returns:
        A boolean array shaped like ``ts`` that is True where
        ``start <= ts < end`` for at least one interval.
    """
    ts = numpy.asarray(ts)
    validity = numpy.zeros(ts.shape, dtype=bool)
    for start, end in times:
        # Accumulating with |= works for any interval count; unpacking the
        # masks into logical_or(*masks) only works for exactly two.
        # The start bound is inclusive to match get_validity_1's [start, end).
        validity |= (start <= ts) & (ts < end)
    return validity
def get_validity_3(ts, times):
    """Return a 0/1 mask over sorted ``ts`` using binary search for the bounds.

    ``numpy.searchsorted`` (default ``side='left'``) returns the first index
    whose value is >= the searched bound, so each interval marks the samples
    with ``start <= ts < end``. A bound past the last sample maps to
    ``len(ts)``.

    Args:
        ts: 1-D sequence of sample times, sorted in ascending order.
        times: Iterable of ``(start, end)`` pairs.

    Returns:
        A float array of length ``len(ts)`` holding 1 inside the intervals
        and 0 elsewhere.
    """
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # The module is imported as `numpy` (there is no `np` alias).
        index_start = numpy.searchsorted(ts, start)
        index_end = numpy.searchsorted(ts, end)
        validity[index_start:index_end] = 1
    return validity
if __name__ == "__main__":
    # Time the three implementations on the same large, evenly spaced time
    # axis and check that they all produce the same mask.
    n = 100000000
    ts = numpy.linspace(0, 1, n)
    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])

    # perf_counter is the monotonic, high-resolution clock meant for timing.
    t0 = time.perf_counter()
    res_1 = get_validity_1(ts, times)
    t_1 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_2 = get_validity_2(ts, times)
    t_2 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_3 = get_validity_3(ts, times)
    t_3 = time.perf_counter() - t0

    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))
    print("t_3: " + str(t_3))

    # Element-wise comparison reduced with .all(): the masks must match
    # at every sample.
    assert (res_1 == res_2).all()
    assert (res_1 == res_3).all()
输出:
t_1: 0.4412200450897217
t_2: 0.3446168899536133
t_3: 0.14597129821777344