我可以在过滤 numpy 数组方面做得更好吗
Can I do better on filtering numpy array
我有一个有点人为的细胞化示例,我想要一个函数来:
- 接受一维
numpy
任意长度的数组 (~100'000 ÷ 1'000'000 np.float64
's)
- 对其进行一些过滤
- return 结果作为一个新的 [numpy?] 长度相同的数组
代码及分析如下:
%%cython -a
from libc.stdlib cimport malloc, free
from cython cimport boundscheck, wraparound
import numpy as np
@boundscheck(False)
@wraparound(False)
def func_memview(double[:] arr):
cdef:
int N = arr.shape[0], i
double *out_ptr = <double *> malloc(N * sizeof(double))
double[:] out = <double[:N]>out_ptr
for i in range(1, N):
if arr[i] > arr[i-1]:
out[i] = arr[i]
else:
out[i] = 0.
free(out_ptr)
return np.asarray(out)
我的问题是我可以用它做得更好吗?
正如 DavidW 指出的那样,您的代码在内存管理方面存在一些问题,最好直接使用 numpy 数组:
%%cython
from cython cimport boundscheck, wraparound
import numpy as np
@boundscheck(False)
@wraparound(False)
def func_memview_correct(double[:] arr):
cdef:
int N = arr.shape[0], i
double[:] out = np.empty(N)
for i in range(1, N):
if arr[i] > arr[i-1]:
out[i] = arr[i]
else:
out[i] = 0.0
return np.asarray(out)
它和有问题的原始版本差不多快:
import numpy as np
np.random.seed(0)
k= np.random.rand(5*10**7)
%timeit func_memview(k) # 413 ms ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit func_memview_correct(k) # 412 ms ± 15.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
问题是如何使这段代码更快?最明显的选项是
- 并行化。
- 使用 vectorization/SIMD 指令。
众所周知,很难确保 Cython 生成的 C 代码得到矢量化,例如参见 [=18=]。对于许多编译器来说,有必要使用连续内存视图来改善这种情况,即:
%%cython -c=/O3
from cython cimport boundscheck, wraparound
import numpy as np
@boundscheck(False)
@wraparound(False)
def func_memview_correct_cont(double[::1] arr): // <---- HERE
cdef:
int N = arr.shape[0], i
double[::1] out = np.empty(N) // <--- HERE
for i in range(1, N):
if arr[i] > arr[i-1]:
out[i] = arr[i]
else:
out[i] = 0.0
return np.asarray(out)
在我的机器上它并没有快多少
%timeit func_memview_correct_cont(k) # 402 ms ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
其他编译器可能会做得更好。但是,我经常看到 gcc 和 msvc 努力为典型的过滤代码生成最佳汇编程序(例如,参见 )。 Clang 在这方面做得更好,所以最简单的解决方案可能是使用 numba
:
import numba as nb
@nb.njit
def nb_func(arr):
N = arr.shape[0]
out = np.empty(N)
for i in range(1, N):
if arr[i] > arr[i-1]:
out[i] = arr[i]
else:
out[i] = 0.0
return out
比 cython 代码快了将近 3 倍:
%timeit nb_func(k) # 151 ms ± 2.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
使用 prange
很容易并行化 numba 版本,但胜利并不多:并行化版本在我的机器上运行 116 毫秒。
总结:对于此类任务,我的建议是使用 numba。使用 cython 比较棘手,最终性能将取决于后台使用的编译器。
我有一个有点人为的细胞化示例,我想要一个函数来:
- 接受一维
numpy
任意长度的数组 (~100'000 ÷ 1'000'000np.float64
's) - 对其进行一些过滤
- return 结果作为一个新的 [numpy?] 长度相同的数组
代码及分析如下:
%%cython -a
from libc.stdlib cimport malloc, free
from cython cimport boundscheck, wraparound
import numpy as np
@boundscheck(False)
@wraparound(False)
def func_memview(double[:] arr):
cdef:
int N = arr.shape[0], i
double *out_ptr = <double *> malloc(N * sizeof(double))
double[:] out = <double[:N]>out_ptr
for i in range(1, N):
if arr[i] > arr[i-1]:
out[i] = arr[i]
else:
out[i] = 0.
free(out_ptr)
return np.asarray(out)
我的问题是我可以用它做得更好吗?
正如 DavidW 指出的那样,您的代码在内存管理方面存在一些问题,最好直接使用 numpy 数组:
%%cython
from cython cimport boundscheck, wraparound
import numpy as np
@boundscheck(False)
@wraparound(False)
def func_memview_correct(double[:] arr):
cdef:
int N = arr.shape[0], i
double[:] out = np.empty(N)
for i in range(1, N):
if arr[i] > arr[i-1]:
out[i] = arr[i]
else:
out[i] = 0.0
return np.asarray(out)
它和有问题的原始版本差不多快:
import numpy as np
np.random.seed(0)
k= np.random.rand(5*10**7)
%timeit func_memview(k) # 413 ms ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit func_memview_correct(k) # 412 ms ± 15.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
问题是如何使这段代码更快?最明显的选项是
- 并行化。
- 使用 vectorization/SIMD 指令。
众所周知,很难确保 Cython 生成的 C 代码得到矢量化,例如参见 [=18=]。对于许多编译器来说,有必要使用连续内存视图来改善这种情况,即:
%%cython -c=/O3
from cython cimport boundscheck, wraparound
import numpy as np
@boundscheck(False)
@wraparound(False)
def func_memview_correct_cont(double[::1] arr): // <---- HERE
cdef:
int N = arr.shape[0], i
double[::1] out = np.empty(N) // <--- HERE
for i in range(1, N):
if arr[i] > arr[i-1]:
out[i] = arr[i]
else:
out[i] = 0.0
return np.asarray(out)
在我的机器上它并没有快多少
%timeit func_memview_correct_cont(k) # 402 ms ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
其他编译器可能会做得更好。但是,我经常看到 gcc 和 msvc 努力为典型的过滤代码生成最佳汇编程序(例如,参见 numba
:
import numba as nb
@nb.njit
def nb_func(arr):
N = arr.shape[0]
out = np.empty(N)
for i in range(1, N):
if arr[i] > arr[i-1]:
out[i] = arr[i]
else:
out[i] = 0.0
return out
比 cython 代码快了将近 3 倍:
%timeit nb_func(k) # 151 ms ± 2.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
使用 prange
很容易并行化 numba 版本,但胜利并不多:并行化版本在我的机器上运行 116 毫秒。
总结:对于此类任务,我的建议是使用 numba。使用 cython 比较棘手,最终性能将取决于后台使用的编译器。