使用 cython 或替代方法优化 np.searchsorted()

Optimising np.searchsorted() with cython or alternative

我有以下功能。我试图优化 np.searchsorted() 函数,并在 Cython 中编写了我自己的函数并取得了很好的结果。但后来意识到将前者简单地包装在 cdef cfunc() 中效果更好,而且我不了解 python、cython 或 numpy 的机制,不知道为什么。什么是最好的 fastest/efficient 编写此函数的方法。在我的脚本中,这是在某些计算过程中被调用数千次的函数之一,我希望每次调用的成本尽可能低。此外,我不需要完全复制 np.searchsorted(),我只需要生成特定的输出 i = np.searchsorted(time_points > t, True)

案例一

%%cython --compile-args=-fopenmp --link-args=-fopenmp -a
cimport openmp
cimport cython
#cimport openmp
"""
%%cython -a
"""
from cpython cimport array
import array
import cython.parallel as cp
from cython.parallel import parallel, prange
import numpy as np
cimport numpy as np
import random
import pickle
import matplotlib.pyplot as pl

from timeit import default_timer as timer
cdef int ntime = 10**4
cdef float t = 2 #some variable in general 
time_points = np.arange(ntime, dtype=np.float64)
#print(time_points)
#search_time =[]
start_update = timer()
i = np.searchsorted(time_points > t, True)
print(i)
end_update = timer()
print(end_update-start_update)

案例二

@cython.boundscheck(False)  # Deactivate bounds checking                                                                  
@cython.wraparound(False)   # Deactivate negative indexing.                                                               
@cython.cdivision(True)     # Deactivate division by 0 checking.
cdef mysearch(np.ndarray[np.float64_t, ndim=1] time_points, float val, int nt):
    cdef int idx
    cdef int return_val
    #cdef int total 
    if val >= nt or nt-1< val<nt:
        print("first")
        return nt
    else:
        for idx in range(nt):
            #print(idx)
            
            if time_points[idx] <= val:
                #return_val = np.int(time_points[idx])
                #break
                #print("sec")
                continue
            else:
                #print("third", idx)
                return_val = np.int(time_points[idx])
                #continue
                break
        
    return return_val
print(mysearch(time_points, t, ntime))
#i = np.searchsorted(time_points > 2, True)
#print(i)
end_update = timer()
b= end_update-start_update
print(b)

案例三

start_update = timer()
cdef cfunc():
    i = np.searchsorted(time_points > t, True)
    print(i, np.searchsorted(time_points > t, True))
    return i
end_update = timer()
a=  end_update-start_update
print(a)

案例一、案例二、案例三的时间是0.00013096071779727936 5.245860666036606e-056.770715117454529e-07

我很高兴案例 2 比案例 1 更好,但为什么案例 3 快得多?有什么特别的事情发生吗?我怎样才能超越这些速度?

np.searchsorted(time_points > t, True) 的主要问题是 time_points > t。事实上,虽然 np.searchsortedO(log n) 时间内运行(ntime_points 的大小),但表达式 time_points > tO(n) 时间内计算并且明显是瓶颈。事实上,您不需要这个子表达式:i = np.searchsorted(time_points, t, 'right') 应该可以正确完成这项工作。请注意,最后一种情况不会在两个计时器之间执行任何操作,正如@DavidW 在评论中所解释的那样:它只是定义了一个 Cython 函数。