在 cython 中修剪一个 numpy 数组

Trimming a numpy array in cython

目前我有以下 cython 函数,修改一个用零填充的 numpy 数组的条目以求和非零值。在我 return 数组之前,我想 trim 它并删除所有非零条目。目前,我使用 numpy 函数 myarray = myarray[~np.all(myarray == 0, axis=1)] 来这样做。我想知道是否有(通常)更快的方法来使用 Cython/C 函数而不是依赖 python/numpy 来执行此操作。这是我脚本中最后的 pythonic 交互之一(使用 to %%cython -a 检查)。但我真的不知道如何处理这个问题。一般来说,我事先不知道最终数组中非零元素的数量。

cdef func():
   np.ndarray[np.float64_t, ndim=2] myarray = np.zeros((lenpropen, 6)) 
   """
   computations
   """

   myarray = myarray[~np.all(myarray == 0, axis=1)]
   return myarray

感谢@Jérôme Richard 的评论。基于此(如果我的理解是正确的)我尝试实现擦除-删除习惯用法。下面给出了示例代码。

myarray = np.zeros((5000,6))
myarray[2] = [1,1,1,1,1,1]
@cython.boundscheck(False)  # Deactivate bounds checking                                                                  
@cython.wraparound(False)   # Deactivate negative indexing.                                                               
@cython.cdivision(True)     # Deactivate division by 0 checking.
cdef erase_remove( np.ndarray[np.float64_t, ndim=2] myarray):
    cdef int idx 
    cdef int cursor = 0
    cdef int length_arr = 5000
    for idx in range(5000):
    
        if myarray[idx,0]!=0 and myarray[idx,1]!=0 and myarray[idx,2]!=0 and myarray[idx,3]!=0 and myarray[idx,4]!=0 and  myarray[idx,5]!=0:
            myarray[cursor,0] = myarray[idx,0]
            myarray[cursor,1] = myarray[idx,1]
            myarray[cursor,2] = myarray[idx,2]
            myarray[cursor,3] = myarray[idx,3]
            myarray[cursor,4] = myarray[idx,4]
            myarray[cursor,5] = myarray[idx,5]
            cursor = cursor +1
        else:
            continue
    return  myarray[0:cursor]    
start = timer()
myarray= erase_remove(myarray)
end = timer()
print("final", myarray)
print("time", end-start)

这会产生输出

final [[1. 1. 1. 1. 1. 1.]]
time 1.1235475540161133e-05

相比
myarray = np.zeros((5000,6))
print(myarray)

myarray[2] = [1,1,1,1,1,1]
print(myarray)
start = timer()
myarray = myarray[~np.all(myarray == 0, axis=1)]
end = timer()
print(myarray)
print("time", end-start)

产生输出

[[1. 1. 1. 1. 1. 1.]]
time 0.0006445050239562988

如果最高维度总是包含像 6 这样的少量元素,那么你的代码不是最好的。

首先,myarray == 0np.all~ 创建了 临时数组 ,这会引入一些额外的开销,因为它们需要被写入并回读。开销取决于临时数组的 this,最大的开销是 myarray == 0.

此外,Numpy 调用会执行一些 Cython 无法删除的不需要的检查。这些检查引入了恒定的时间开销。因此,它对于小输入数组可能相当大,但对于大输入数组则不然。

此外,如果 np.all 的代码知道最后一个维度的确切大小,那么它的代码会更快,而这里不是这种情况。实际上,np.all 的循环理论上可以 展开 ,因为最后一个维度很小。不幸的是,Cython 不优化 Numpy 调用 并且 Numpy 是针对可变输入大小编译的,因此 在 compile-time 处未知。

最后,如果lenpropen 很大,计算可以并行化(否则这不会更快,实际上可能会更慢)。但是,请注意,并行实现需要分两步完成计算:np.all(myarray == 0, axis=1) 需要并行计算,然后您可以创建结果数组并通过并行计算 myarray[~result] 来写入它。按顺序,您可以通过 过滤行 in-place 直接覆盖 myarray,然后生成过滤后的行的视图。这种模式被称为 erase-remove idiom。请注意,这假设数组是 连续的.

总而言之,一个更快的实现包括编写 2 个嵌套循环在 myarray 上迭代,最里面的迭代次数为常数。关于 lenpropen 的大小,您可以使用基于 erase-remove 习惯用法的顺序 in-place 实现,或者使用两个步骤(和一个临时数组)的并行 out-of-place 实现.