我可以在过滤 numpy 数组方面做得更好吗

Can I do better on filtering numpy array

我有一个有点人为的细胞化示例,我想要一个函数来:

  1. 接受一维 numpy 任意长度的数组 (~100'000 ÷ 1'000'000 np.float64's)
  2. 对其进行一些过滤
  3. return 结果作为一个新的 [numpy?] 长度相同的数组

代码及分析如下:

%%cython -a

from libc.stdlib cimport malloc, free
from cython cimport boundscheck, wraparound
import numpy as np

@boundscheck(False)
@wraparound(False)
def func_memview(double[:] arr):
    cdef:
        int N = arr.shape[0], i
        double *out_ptr = <double *> malloc(N * sizeof(double))
        double[:] out = <double[:N]>out_ptr
    for i in range(1, N):
        if arr[i] > arr[i-1]:
            out[i] = arr[i]
        else:
            out[i] = 0.
    free(out_ptr)
    return np.asarray(out)

我的问题是我可以用它做得更好吗?

正如 DavidW 指出的那样,您的代码在内存管理方面存在一些问题,最好直接使用 numpy 数组:

%%cython

from cython cimport boundscheck, wraparound
import numpy as np

@boundscheck(False)
@wraparound(False)
def func_memview_correct(double[:] arr):
    cdef:
        int N = arr.shape[0], i
        double[:] out = np.empty(N)
    for i in range(1, N):
        if arr[i] > arr[i-1]:
            out[i] = arr[i]
        else:
            out[i] = 0.0
    return np.asarray(out)

它和有问题的原始版本差不多快:

import numpy as np
np.random.seed(0)
k= np.random.rand(5*10**7)

%timeit func_memview(k)          # 413 ms ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit func_memview_correct(k)  # 412 ms ± 15.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

问题是如何使这段代码更快?最明显的选项是

  1. 并行化。
  2. 使用 vectorization/SIMD 指令。

众所周知,很难确保 Cython 生成的 C 代码得到矢量化,例如参见 [​​=18=]。对于许多编译器来说,有必要使用连续内存视图来改善这种情况,即:

%%cython -c=/O3

from cython cimport boundscheck, wraparound
import numpy as np

@boundscheck(False)
@wraparound(False)
def func_memview_correct_cont(double[::1] arr):  // <---- HERE
    cdef:
        int N = arr.shape[0], i
        double[::1] out = np.empty(N)   // <--- HERE
    for i in range(1, N):
        if arr[i] > arr[i-1]:
            out[i] = arr[i]
        else:
            out[i] = 0.0
    return np.asarray(out)

在我的机器上它并没有快多少

%timeit func_memview_correct_cont(k)  # 402 ms ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

其他编译器可能会做得更好。但是,我经常看到 gcc 和 msvc 努力为典型的过滤代码生成最佳汇编程序(例如,参见 )。 Clang 在这方面做得更好,所以最简单的解决方案可能是使用 numba:

import numba as nb

@nb.njit
def nb_func(arr):
    N = arr.shape[0]
    out = np.empty(N)
    for i in range(1, N):
        if arr[i] > arr[i-1]:
            out[i] = arr[i]
        else:
            out[i] = 0.0
    return out

比 cython 代码快了将近 3 倍:

%timeit nb_func(k)  # 151 ms ± 2.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

使用 prange 很容易并行化 numba 版本,但胜利并不多:并行化版本在我的机器上运行 116 毫秒。


总结:对于此类任务,我的建议是使用 numba。使用 cython 比较棘手,最终性能将取决于后台使用的编译器。