Cython:相同的速度定义 numpy 的静态类型
Cython: same speed defining the static type of numpy
我必须使用 python 对微分方程进行数值求解。基本上我有两个不同的代码。一个负责读取问题的初始条件,一个负责计算所有的帐户。我想使用 cython 优化第二个。
当我将常量的静态类型(dz,dt,i,k,j ..)定义为floating或int时,我减少了四分之一的计算时间。现在,当我为 numpy 数组定义静态类型时,我没有任何改进。
这是我的代码(.pyx):
import numpy as np
cimport numpy as np
DTYPE = np.int
ctypedef np.int_t DTYPE_t
def explicit_cython(np.ndarray u, float kappa, float dt, float dz, np.ndarray term_const, unsigned int nz, plot_time):
'''Cython version of explicit method'''
#Defining C types
cdef unsigned int i, k, j
cdef unsigned int len_plot = len(plot_time) - 1
cdef float lamnda = kappa*dt/dz**2
u_out = []
u_out.append(u.copy())
for i in range(len_plot):
for k in range(plot_time[i], plot_time[i+1]):
un = u.copy()
for j in range(1, nz-1):
u[j] = un[j] + lamnda*(un[j+1] - 2*un[j] + un[j-1]) + term_const[j]
u_out.append(u.copy())
return u_out
这是我用来编译的设置。
from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize
extensions=[Extension("explicit_cython2",["explicit_cython2.pyx"])]
setup(
ext_modules = cythonize(extensions)
)
当我制作 python3 setup.py build_ext --inplace
时,有这个警告:
In file included from /usr/include/numpy/ndarraytypes.h:1728:0,
from /usr/include/numpy/ndarrayobject.h:17,
from /usr/include/numpy/arrayobject.h:15,
from explicit_cython2.c:258:
/usr/include/numpy/npy_deprecated_api.h:11:2: warning: #warning "Using deprecated NumPy API, disable it by #defining NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
#warning "Using deprecated NumPy API, disable it by #defining NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION"
为什么没有通过定义numpy的静态类型来赢得速度?为什么我有这个警告?谢谢!
PD。我在 LMDE
中使用 python 3.4 和 Anaconda
A) 除非您可以定义 numpy 数组的维度和内部数据类型,否则您可能不会得到任何好处
def explicit_cython(np.ndarray[np.float_t,ndim=2],...
B) 我认为已弃用的警告是说新的更好的界面是类型化内存视图 http://docs.cython.org/src/userguide/memoryviews.html。如果您不想使用这些,请忽略它。
C) 你可能会失去很多速度复制的东西,你会立即覆盖每一步,如果你能做到 np.zeros(n.shape)
你可能会有所收获。 (或者甚至只是跳过内部 for k
循环中的副本)。
D) 循环的主要内容可以被矢量化,无论如何都要避免使用 Cython。
我必须使用 python 对微分方程进行数值求解。基本上我有两个不同的代码。一个负责读取问题的初始条件,一个负责计算所有的帐户。我想使用 cython 优化第二个。 当我将常量的静态类型(dz,dt,i,k,j ..)定义为floating或int时,我减少了四分之一的计算时间。现在,当我为 numpy 数组定义静态类型时,我没有任何改进。
这是我的代码(.pyx):
import numpy as np
cimport numpy as np
DTYPE = np.int
ctypedef np.int_t DTYPE_t
def explicit_cython(np.ndarray u, float kappa, float dt, float dz, np.ndarray term_const, unsigned int nz, plot_time):
'''Cython version of explicit method'''
#Defining C types
cdef unsigned int i, k, j
cdef unsigned int len_plot = len(plot_time) - 1
cdef float lamnda = kappa*dt/dz**2
u_out = []
u_out.append(u.copy())
for i in range(len_plot):
for k in range(plot_time[i], plot_time[i+1]):
un = u.copy()
for j in range(1, nz-1):
u[j] = un[j] + lamnda*(un[j+1] - 2*un[j] + un[j-1]) + term_const[j]
u_out.append(u.copy())
return u_out
这是我用来编译的设置。
from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize
extensions=[Extension("explicit_cython2",["explicit_cython2.pyx"])]
setup(
ext_modules = cythonize(extensions)
)
当我制作 python3 setup.py build_ext --inplace
时,有这个警告:
In file included from /usr/include/numpy/ndarraytypes.h:1728:0,
from /usr/include/numpy/ndarrayobject.h:17,
from /usr/include/numpy/arrayobject.h:15,
from explicit_cython2.c:258:
/usr/include/numpy/npy_deprecated_api.h:11:2: warning: #warning "Using deprecated NumPy API, disable it by #defining NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
#warning "Using deprecated NumPy API, disable it by #defining NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION"
为什么没有通过定义numpy的静态类型来赢得速度?为什么我有这个警告?谢谢!
PD。我在 LMDE
中使用 python 3.4 和 AnacondaA) 除非您可以定义 numpy 数组的维度和内部数据类型,否则您可能不会得到任何好处
def explicit_cython(np.ndarray[np.float_t,ndim=2],...
B) 我认为已弃用的警告是说新的更好的界面是类型化内存视图 http://docs.cython.org/src/userguide/memoryviews.html。如果您不想使用这些,请忽略它。
C) 你可能会失去很多速度复制的东西,你会立即覆盖每一步,如果你能做到 np.zeros(n.shape)
你可能会有所收获。 (或者甚至只是跳过内部 for k
循环中的副本)。
D) 循环的主要内容可以被矢量化,无论如何都要避免使用 Cython。