加速cython代码
Speed up cython code
我写了一个 python 代码来管理大量数据,因此需要花费很多时间。所以,我发现了 Cython 并开始更改我的代码。
基本上,我所做的就是更改函数的声明(cdef 类型名称(变量类型的参数)),声明 cdef 变量及其类型,并声明 cdef 类。
我用 Eclipse 编写所有的 .pyx 文件,使用命令 python setup.py build_ext --inplace 进行编译,然后在 Eclipse 中运行编译后的程序。
我的问题是 python 与 cython 速度比较,没有任何区别。
我运行了命令 cython -a <file>,它生成了一个 HTML 文件,其中有很多黄色的行。
我不知道是不是哪里做错了、是否还需要添加别的东西,也不知道如何消除这些黄色的行。
我只是粘贴了一些代码行,那是我想加快速度的部分,因为代码很长。
main.pyx
'''there are a lot of ndarray objects stored in a file and in this step I get each of them until there are no more items '''
cdef ReadWavePoints (WavePointManagement wavePointManagement, ColumnManagement columnManagement):
    '''
    Process every batch of wave points stored in ``wavePointsFile``.

    Each batch is loaded, has its coordinates rounded, is sorted and is then
    handed to ``GroupColumnsVoxels``.  Processing stops at the first exception
    raised by any of those steps; both files are always closed afterwards.

    :param wavePointManagement: loader that reads the wave point file.
    :param columnManagement: receives the grouped columns; its output file is
        closed when processing stops.
    '''
    wavePointManagement.OpenWavePointFileLoad(wavePointsFile)
    try:
        while True:
            try:
                wavePointManagement.LoadWavePointFile()
                wavePointManagement.RoundCoordinates()
                wavePointManagement.SortWavePointList()
                GroupColumnsVoxels(wavePointManagement.GetWavePointList(), columnManagement)
            except Exception:
                # Any failure ends the run; presumably this is how the end of the
                # file is detected.
                # NOTE(review): real errors are hidden here too -- consider
                # logging them once the end-of-file signal is confirmed.
                # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt
                # and SystemExit propagate.
                break
    finally:
        # Close both files on every exit path, including an interrupt.
        wavePointManagement.CloseWavePointFile()
        columnManagement.CloseWriteColumnFile()
'''I check which points are in the same XYZ (voxel) and in the same XY (column)'''
cdef GroupColumnsVoxels (object wavePointList, ColumnManagement columnManagement):
'''
Group consecutive points of ``wavePointList`` by voxel (same X, Y and Z) and by
column (same X and Y).

A reference index walks the list.  Each following point is compared with the
reference point through GetX/GetY/GetZ; the largest GetValue seen is kept in
``voxelValue`` and passed to ``CheckVoxel``.  When X or Y changes, a column is
created with ``columnInstance.Column``, added through
``columnManagement.AddColumn`` and ``MaximumHeightColumn`` is called.

The caller in this file sorts the list before passing it in.

NOTE(review): indentation was lost in this listing, so the exact nesting of the
``else`` / ``break`` statements below cannot be confirmed from here -- verify
against the original file before changing any logic.
'''
cdef int indexWavePointRef, indexWavePoint
cdef int saved
cdef double voxelValue
cdef int sizeWavePointList
# Length is cached here, but the for-loop below calls len() again.
sizeWavePointList = len(wavePointList)
indexWavePointRef = 0
while indexWavePointRef < sizeWavePointList - 1:
# saved records whether one of the branches below already called CheckVoxel.
saved = 0
voxelValue = (wavePointList[indexWavePointRef]).GetValue()
for indexWavePoint in xrange(indexWavePointRef + 1, len(wavePointList)):
# Same X and Y as the reference point?
if (wavePointList[indexWavePointRef]).GetX() == (wavePointList[indexWavePoint]).GetX() and (wavePointList[indexWavePointRef]).GetY() == (wavePointList[indexWavePoint]).GetY():
# Same Z too: keep the larger of the two values.
if (wavePointList[indexWavePointRef]).GetZ() == (wavePointList[indexWavePoint]).GetZ():
if voxelValue < (wavePointList[indexWavePoint]).GetValue():
voxelValue = (wavePointList[indexWavePoint]).GetValue()
else:
saved = 1
# Store the accumulated value for the reference point, then continue from
# the point that differed.
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
indexWavePointRef = indexWavePoint
# The new reference is the last point: store it directly with its own value.
if indexWavePointRef == sizeWavePointList - 1:
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), (wavePointList[indexWavePointRef]).GetValue())
break
else:
saved = 1
# Store the value, register the column of the reference point and report its Z.
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
columnManagement.AddColumn(columnObject)
MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ())
indexWavePointRef = indexWavePoint
break
# saved == 0: neither branch above stored the value -- presumably the for-loop
# reached the end of the list; TODO confirm against the original indentation.
if saved == 0:
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
indexWavePointRef = indexWavePoint
columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
columnManagement.AddColumn(columnObject)
MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ())
'''I check if the data stored in a voxel is lower than the new one; if its the case, I store it'''
cdef CheckVoxel (double X, double Y, double Z, double newValue):
    '''
    Overwrite one cell of ``datasetVoxels`` when ``newValue`` is larger than the
    value currently stored there.

    The band is ``floor(Z / 0.3) + 1``; the cell is addressed by
    ``floor((X - Xmin) / 0.25)`` and ``floor((Ymax - Y) / 0.25)``.
    (0.25 and 0.3 are presumably the cell size and the slice height -- confirm.)

    :param X: X coordinate of the point.
    :param Y: Y coordinate of the point.
    :param Z: Z coordinate of the point; selects the band.
    :param newValue: candidate value for the cell.
    '''
    cdef object band, rawCell, packedValue, xOffset, yOffset
    cdef tuple storedValue
    band = datasetVoxels.GetRasterBand(int(math.floor(Z/0.3))+1)
    # Both the read and the write address the same single cell.
    xOffset = int(math.floor((X-Xmin)/0.25))
    yOffset = int(math.floor((Ymax-Y)/0.25))
    rawCell = band.ReadRaster(xOffset, yOffset, 1, 1, buf_type=gdal.GDT_Float32)
    storedValue = struct.unpack('f', rawCell)
    if newValue > storedValue[0]:
        packedValue = struct.pack('f', newValue)
        band.WriteRaster(xOffset, yOffset, 1, 1, packedValue)
'''I check if this point has the highest Z and I store this information'''
cdef MaximumHeightColumn(double X, double Y, double newZ):
    '''
    Overwrite one cell of band 10 of ``datasetMetrics`` when ``newZ`` is larger
    than the stored value rounded to one decimal.

    The cell is addressed by ``floor((X - Xmin) / 0.25)`` and
    ``floor((Ymax - Y) / 0.25)``.

    :param X: X coordinate of the column.
    :param Y: Y coordinate of the column.
    :param newZ: candidate height for the column.
    '''
    cdef object band, rawCell, packedHeight, xOffset, yOffset
    cdef tuple storedHeight
    band = datasetMetrics.GetRasterBand(10)
    # Both the read and the write address the same single cell.
    xOffset = int(math.floor((X-Xmin)/0.25))
    yOffset = int(math.floor((Ymax-Y)/0.25))
    rawCell = band.ReadRaster(xOffset, yOffset, 1, 1, buf_type=gdal.GDT_Float32)
    storedHeight = struct.unpack('f', rawCell)
    if newZ > round(storedHeight[0], 1):
        packedHeight = struct.pack('f', newZ)
        band.WriteRaster(xOffset, yOffset, 1, 1, packedHeight)
WavePointManagement.pyx
'''this class serializes, rounds and sorts the points of each ndarray'''
import cPickle as pickle
import numpy as np
cimport numpy as np
import math
cdef class WavePointManagement(object):
    '''
    Manage the points extracted from the waveform.

    Batches of points are unpickled one at a time from a file; each batch can
    then have its coordinates rounded and be sorted by (x, y, z).
    '''
    # fileObject: file the batches are read from (None until opened).
    # wavePointList: the batch currently held.
    cdef object fileObject, wavePointList
    __slots__ = ('wavePointList', 'fileObject')

    def __cinit__(self):
        '''
        Start with no open file and an empty array of points.
        '''
        self.fileObject = None
        self.wavePointList = np.array([])

    cdef object GetWavePointList(self):
        '''Return the batch of points currently held.'''
        return self.wavePointList

    cdef void OpenWavePointFileLoad (self, object fileName):
        '''Open ``fileName`` for binary reading.'''
        # ``open`` rather than ``file``: the ``file`` builtin does not exist on
        # Python 3, while ``open`` behaves the same on Python 2.
        self.fileObject = open(fileName, 'rb')

    cdef void LoadWavePointFile (self):
        '''Replace the current batch with the next object pickled in the file.'''
        # The old batch is dropped first, so the attribute is None if the load
        # below fails (presumably also to free memory before unpickling).
        self.wavePointList = None
        # SECURITY: unpickling runs arbitrary code; only load trusted files.
        self.wavePointList = pickle.load(self.fileObject)

    cdef void SortWavePointList (self):
        '''Sort the current batch by (x, y, z); the result is a Python list.'''
        self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))

    cdef void RoundCoordinates (self):
        '''
        Round every point of the current batch in place: X down and Y up to a
        multiple of 0.25, Z down to a multiple of 0.3.
        '''
        for pointObject in self.GetWavePointList():
            # The outer round() trims the floating-point noise of the product.
            pointObject.SetX(round(math.floor(pointObject.GetX()/0.25)*0.25, 2))
            pointObject.SetY(round(math.ceil(pointObject.GetY()/0.25)*0.25, 2))
            pointObject.SetZ(round(math.floor(pointObject.GetZ()/0.3)*0.3, 1))

    cdef void CloseWavePointFile(self):
        '''Close the file opened by ``OpenWavePointFileLoad``.'''
        self.fileObject.close()
setup.py
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy
# Build the ``main`` extension from main.pyx; the NumPy headers are needed
# because the Cython sources cimport numpy.
ext = Extension(
    name="main",
    sources=["main.pyx"],
    include_dirs=[numpy.get_include()],
)

# Cython's build_ext command translates the .pyx source before compiling it.
setup(
    cmdclass={'build_ext': build_ext},
    ext_modules=[ext],
)
test_cython.py
'''Script run from Eclipse after the extension has been compiled.'''
# Needs the compiled ``main`` extension module to be importable.
from main import main
# Runs at import time: this script has no ``__main__`` guard.
main()
我怎样才能加快这段代码的速度?
您的代码在使用 numpy 数组和列表之间来回跳转。因此,cython 生成的代码几乎没有区别。
下面这行代码会生成一个 Python 列表,而且排序所用的 key 函数也是纯 Python 函数。
self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))
您需要使用 ndarray.sort(如果不想就地排序,则可以使用 numpy.sort)。为此,您还需要更改对象在数组中的存储方式,也就是说,您需要使用结构化数组(structured array)。关于如何对结构化数组进行排序,请参阅 numpy.sort 文档中的示例——尤其是该页面上的最后两个示例。
将数据存储到 NumPy 数组中之后,您还需要告诉 Cython 数据在数组中的存储方式,包括提供元素的类型信息和数组的维数。这个页面提供了有关如何高效使用 NumPy 数组的更多信息。
创建和排序结构化数组的示例:
import numpy as np
cimport numpy as np
# NumPy description of one record; the packed struct below has the same
# fields in the same order.
DTYPE = [
    ('name', 'S10'),
    ('height', np.float64),
    ('age', np.int32),
]

cdef packed struct Person:
    char name[10]
    np.float64_t height
    np.int32_t age

ctypedef Person DTYPE_t


def create_array():
    """Return a structured array with three sample records."""
    knights = [
        ('Arthur', 1.8, 41),
        ('Lancelot', 1.9, 38),
        ('Galahad', 1.7, 38),
    ]
    return np.array(knights, dtype=DTYPE)


cpdef sort_by_age_then_height(np.ndarray[DTYPE_t, ndim=1] arr):
    """Sort ``arr`` in place by age, breaking ties by height."""
    arr.sort(order=['age', 'height'])
最后,您需要把代码中调用的 Python 方法换成标准 C 库函数,以进一步提高速度。下面以 RoundCoordinates 为例。`cpdef` 表示该函数还会通过一个包装函数暴露给 Python。
cimport cython
cimport numpy as np
from libc.math cimport floor, ceil, round
import numpy as np
# NumPy description of one point; the packed struct below has the same fields
# in the same order.
DTYPE = [('x', np.float64), ('y', np.float64), ('z', np.float64)]

cdef packed struct Point3D:
    np.float64_t x, y, z

ctypedef Point3D DTYPE_t


# Caution should be used when turning the bounds check off as it can lead to undefined
# behaviour if you use an invalid index.
@cython.boundscheck(False)
cpdef RoundCoordinates_cy(np.ndarray[DTYPE_t] pointlist):
    """Round every point of ``pointlist`` in place.

    x is rounded down and y up to a multiple of 0.25 (two decimals kept);
    z is rounded down to a multiple of 0.3 (one decimal kept).

    :param pointlist: one-dimensional structured array with fields x, y, z.
    """
    cdef Py_ssize_t i  # array lengths can exceed the range of a C int
    cdef DTYPE_t point
    for i in range(pointlist.shape[0]):  # this line is optimised into a c loop
        point = pointlist[i]  # creates a copy of the point
        # C round() has no digits argument, so scale, round, then unscale.
        # x and y need two decimals (0.25 steps): scaling by 10 would turn
        # 0.25 into 0.3.
        point.x = round(floor(point.x/0.25)*25) / 100
        point.y = round(ceil(point.y/0.25)*25) / 100
        point.z = round(floor(point.z/0.3)*3) / 10
        pointlist[i] = point  # overwrites the old point data with the new data
最后,在重写整个代码库之前,您应该分析代码以查看程序花费大部分时间的函数并在优化其他函数之前优化这些函数。
我写了一个 python 代码来管理大量数据,因此需要花费很多时间。所以,我发现了 Cython 并开始更改我的代码。
基本上,我所做的就是更改函数的声明(cdef 类型名称(变量类型的参数)),声明 cdef 变量及其类型,并声明 cdef 类。
我用 Eclipse 编写所有的 .pyx 文件,使用命令 python setup.py build_ext --inplace 进行编译,然后在 Eclipse 中运行编译后的程序。
我的问题是 python 与 cython 速度比较,没有任何区别。
我运行了命令 cython -a <file>,它生成了一个 HTML 文件,其中有很多黄色的行。
我不知道是不是哪里做错了、是否还需要添加别的东西,也不知道如何消除这些黄色的行。
我只是粘贴了一些代码行,那是我想加快速度的部分,因为代码很长。
main.pyx
'''there are a lot of ndarray objects stored in a file and in this step I get each of them until there are no more items '''
cdef ReadWavePoints (WavePointManagement wavePointManagement, ColumnManagement columnManagement):
    '''
    Process every batch of wave points stored in ``wavePointsFile``.

    Each batch is loaded, has its coordinates rounded, is sorted and is then
    handed to ``GroupColumnsVoxels``.  Processing stops at the first exception
    raised by any of those steps; both files are always closed afterwards.

    :param wavePointManagement: loader that reads the wave point file.
    :param columnManagement: receives the grouped columns; its output file is
        closed when processing stops.
    '''
    wavePointManagement.OpenWavePointFileLoad(wavePointsFile)
    try:
        while True:
            try:
                wavePointManagement.LoadWavePointFile()
                wavePointManagement.RoundCoordinates()
                wavePointManagement.SortWavePointList()
                GroupColumnsVoxels(wavePointManagement.GetWavePointList(), columnManagement)
            except Exception:
                # Any failure ends the run; presumably this is how the end of the
                # file is detected.
                # NOTE(review): real errors are hidden here too -- consider
                # logging them once the end-of-file signal is confirmed.
                # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt
                # and SystemExit propagate.
                break
    finally:
        # Close both files on every exit path, including an interrupt.
        wavePointManagement.CloseWavePointFile()
        columnManagement.CloseWriteColumnFile()
'''I check which points are in the same XYZ (voxel) and in the same XY (column)'''
cdef GroupColumnsVoxels (object wavePointList, ColumnManagement columnManagement):
'''
Group consecutive points of ``wavePointList`` by voxel (same X, Y and Z) and by
column (same X and Y).

A reference index walks the list.  Each following point is compared with the
reference point through GetX/GetY/GetZ; the largest GetValue seen is kept in
``voxelValue`` and passed to ``CheckVoxel``.  When X or Y changes, a column is
created with ``columnInstance.Column``, added through
``columnManagement.AddColumn`` and ``MaximumHeightColumn`` is called.

The caller in this file sorts the list before passing it in.

NOTE(review): indentation was lost in this listing, so the exact nesting of the
``else`` / ``break`` statements below cannot be confirmed from here -- verify
against the original file before changing any logic.
'''
cdef int indexWavePointRef, indexWavePoint
cdef int saved
cdef double voxelValue
cdef int sizeWavePointList
# Length is cached here, but the for-loop below calls len() again.
sizeWavePointList = len(wavePointList)
indexWavePointRef = 0
while indexWavePointRef < sizeWavePointList - 1:
# saved records whether one of the branches below already called CheckVoxel.
saved = 0
voxelValue = (wavePointList[indexWavePointRef]).GetValue()
for indexWavePoint in xrange(indexWavePointRef + 1, len(wavePointList)):
# Same X and Y as the reference point?
if (wavePointList[indexWavePointRef]).GetX() == (wavePointList[indexWavePoint]).GetX() and (wavePointList[indexWavePointRef]).GetY() == (wavePointList[indexWavePoint]).GetY():
# Same Z too: keep the larger of the two values.
if (wavePointList[indexWavePointRef]).GetZ() == (wavePointList[indexWavePoint]).GetZ():
if voxelValue < (wavePointList[indexWavePoint]).GetValue():
voxelValue = (wavePointList[indexWavePoint]).GetValue()
else:
saved = 1
# Store the accumulated value for the reference point, then continue from
# the point that differed.
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
indexWavePointRef = indexWavePoint
# The new reference is the last point: store it directly with its own value.
if indexWavePointRef == sizeWavePointList - 1:
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), (wavePointList[indexWavePointRef]).GetValue())
break
else:
saved = 1
# Store the value, register the column of the reference point and report its Z.
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
columnManagement.AddColumn(columnObject)
MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ())
indexWavePointRef = indexWavePoint
break
# saved == 0: neither branch above stored the value -- presumably the for-loop
# reached the end of the list; TODO confirm against the original indentation.
if saved == 0:
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
indexWavePointRef = indexWavePoint
columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
columnManagement.AddColumn(columnObject)
MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ())
'''I check if the data stored in a voxel is lower than the new one; if its the case, I store it'''
cdef CheckVoxel (double X, double Y, double Z, double newValue):
    '''
    Overwrite one cell of ``datasetVoxels`` when ``newValue`` is larger than the
    value currently stored there.

    The band is ``floor(Z / 0.3) + 1``; the cell is addressed by
    ``floor((X - Xmin) / 0.25)`` and ``floor((Ymax - Y) / 0.25)``.
    (0.25 and 0.3 are presumably the cell size and the slice height -- confirm.)

    :param X: X coordinate of the point.
    :param Y: Y coordinate of the point.
    :param Z: Z coordinate of the point; selects the band.
    :param newValue: candidate value for the cell.
    '''
    cdef object band, rawCell, packedValue, xOffset, yOffset
    cdef tuple storedValue
    band = datasetVoxels.GetRasterBand(int(math.floor(Z/0.3))+1)
    # Both the read and the write address the same single cell.
    xOffset = int(math.floor((X-Xmin)/0.25))
    yOffset = int(math.floor((Ymax-Y)/0.25))
    rawCell = band.ReadRaster(xOffset, yOffset, 1, 1, buf_type=gdal.GDT_Float32)
    storedValue = struct.unpack('f', rawCell)
    if newValue > storedValue[0]:
        packedValue = struct.pack('f', newValue)
        band.WriteRaster(xOffset, yOffset, 1, 1, packedValue)
'''I check if this point has the highest Z and I store this information'''
cdef MaximumHeightColumn(double X, double Y, double newZ):
    '''
    Overwrite one cell of band 10 of ``datasetMetrics`` when ``newZ`` is larger
    than the stored value rounded to one decimal.

    The cell is addressed by ``floor((X - Xmin) / 0.25)`` and
    ``floor((Ymax - Y) / 0.25)``.

    :param X: X coordinate of the column.
    :param Y: Y coordinate of the column.
    :param newZ: candidate height for the column.
    '''
    cdef object band, rawCell, packedHeight, xOffset, yOffset
    cdef tuple storedHeight
    band = datasetMetrics.GetRasterBand(10)
    # Both the read and the write address the same single cell.
    xOffset = int(math.floor((X-Xmin)/0.25))
    yOffset = int(math.floor((Ymax-Y)/0.25))
    rawCell = band.ReadRaster(xOffset, yOffset, 1, 1, buf_type=gdal.GDT_Float32)
    storedHeight = struct.unpack('f', rawCell)
    if newZ > round(storedHeight[0], 1):
        packedHeight = struct.pack('f', newZ)
        band.WriteRaster(xOffset, yOffset, 1, 1, packedHeight)
WavePointManagement.pyx
'''this class serializes, rounds and sorts the points of each ndarray'''
import cPickle as pickle
import numpy as np
cimport numpy as np
import math
cdef class WavePointManagement(object):
    '''
    Manage the points extracted from the waveform.

    Batches of points are unpickled one at a time from a file; each batch can
    then have its coordinates rounded and be sorted by (x, y, z).
    '''
    # fileObject: file the batches are read from (None until opened).
    # wavePointList: the batch currently held.
    cdef object fileObject, wavePointList
    __slots__ = ('wavePointList', 'fileObject')

    def __cinit__(self):
        '''
        Start with no open file and an empty array of points.
        '''
        self.fileObject = None
        self.wavePointList = np.array([])

    cdef object GetWavePointList(self):
        '''Return the batch of points currently held.'''
        return self.wavePointList

    cdef void OpenWavePointFileLoad (self, object fileName):
        '''Open ``fileName`` for binary reading.'''
        # ``open`` rather than ``file``: the ``file`` builtin does not exist on
        # Python 3, while ``open`` behaves the same on Python 2.
        self.fileObject = open(fileName, 'rb')

    cdef void LoadWavePointFile (self):
        '''Replace the current batch with the next object pickled in the file.'''
        # The old batch is dropped first, so the attribute is None if the load
        # below fails (presumably also to free memory before unpickling).
        self.wavePointList = None
        # SECURITY: unpickling runs arbitrary code; only load trusted files.
        self.wavePointList = pickle.load(self.fileObject)

    cdef void SortWavePointList (self):
        '''Sort the current batch by (x, y, z); the result is a Python list.'''
        self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))

    cdef void RoundCoordinates (self):
        '''
        Round every point of the current batch in place: X down and Y up to a
        multiple of 0.25, Z down to a multiple of 0.3.
        '''
        for pointObject in self.GetWavePointList():
            # The outer round() trims the floating-point noise of the product.
            pointObject.SetX(round(math.floor(pointObject.GetX()/0.25)*0.25, 2))
            pointObject.SetY(round(math.ceil(pointObject.GetY()/0.25)*0.25, 2))
            pointObject.SetZ(round(math.floor(pointObject.GetZ()/0.3)*0.3, 1))

    cdef void CloseWavePointFile(self):
        '''Close the file opened by ``OpenWavePointFileLoad``.'''
        self.fileObject.close()
setup.py
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy
# Build the ``main`` extension from main.pyx; the NumPy headers are needed
# because the Cython sources cimport numpy.
ext = Extension(
    name="main",
    sources=["main.pyx"],
    include_dirs=[numpy.get_include()],
)

# Cython's build_ext command translates the .pyx source before compiling it.
setup(
    cmdclass={'build_ext': build_ext},
    ext_modules=[ext],
)
test_cython.py
'''Script run from Eclipse after the extension has been compiled.'''
# Needs the compiled ``main`` extension module to be importable.
from main import main
# Runs at import time: this script has no ``__main__`` guard.
main()
我怎样才能加快这段代码的速度?
您的代码在使用 numpy 数组和列表之间来回跳转。因此,cython 生成的代码几乎没有区别。
下面这行代码会生成一个 Python 列表,而且排序所用的 key 函数也是纯 Python 函数。
self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))
您需要使用 ndarray.sort(如果不想就地排序,则可以使用 numpy.sort)。为此,您还需要更改对象在数组中的存储方式,也就是说,您需要使用结构化数组(structured array)。关于如何对结构化数组进行排序,请参阅 numpy.sort 文档中的示例——尤其是该页面上的最后两个示例。
将数据存储到 NumPy 数组中之后,您还需要告诉 Cython 数据在数组中的存储方式,包括提供元素的类型信息和数组的维数。这个页面提供了有关如何高效使用 NumPy 数组的更多信息。
创建和排序结构化数组的示例:
import numpy as np
cimport numpy as np
# NumPy description of one record; the packed struct below has the same
# fields in the same order.
DTYPE = [
    ('name', 'S10'),
    ('height', np.float64),
    ('age', np.int32),
]

cdef packed struct Person:
    char name[10]
    np.float64_t height
    np.int32_t age

ctypedef Person DTYPE_t


def create_array():
    """Return a structured array with three sample records."""
    knights = [
        ('Arthur', 1.8, 41),
        ('Lancelot', 1.9, 38),
        ('Galahad', 1.7, 38),
    ]
    return np.array(knights, dtype=DTYPE)


cpdef sort_by_age_then_height(np.ndarray[DTYPE_t, ndim=1] arr):
    """Sort ``arr`` in place by age, breaking ties by height."""
    arr.sort(order=['age', 'height'])
最后,您需要把代码中调用的 Python 方法换成标准 C 库函数,以进一步提高速度。下面以 RoundCoordinates 为例。`cpdef` 表示该函数还会通过一个包装函数暴露给 Python。
cimport cython
cimport numpy as np
from libc.math cimport floor, ceil, round
import numpy as np
# NumPy description of one point; the packed struct below has the same fields
# in the same order.
DTYPE = [('x', np.float64), ('y', np.float64), ('z', np.float64)]

cdef packed struct Point3D:
    np.float64_t x, y, z

ctypedef Point3D DTYPE_t


# Caution should be used when turning the bounds check off as it can lead to undefined
# behaviour if you use an invalid index.
@cython.boundscheck(False)
cpdef RoundCoordinates_cy(np.ndarray[DTYPE_t] pointlist):
    """Round every point of ``pointlist`` in place.

    x is rounded down and y up to a multiple of 0.25 (two decimals kept);
    z is rounded down to a multiple of 0.3 (one decimal kept).

    :param pointlist: one-dimensional structured array with fields x, y, z.
    """
    cdef Py_ssize_t i  # array lengths can exceed the range of a C int
    cdef DTYPE_t point
    for i in range(pointlist.shape[0]):  # this line is optimised into a c loop
        point = pointlist[i]  # creates a copy of the point
        # C round() has no digits argument, so scale, round, then unscale.
        # x and y need two decimals (0.25 steps): scaling by 10 would turn
        # 0.25 into 0.3.
        point.x = round(floor(point.x/0.25)*25) / 100
        point.y = round(ceil(point.y/0.25)*25) / 100
        point.z = round(floor(point.z/0.3)*3) / 10
        pointlist[i] = point  # overwrites the old point data with the new data
最后,在重写整个代码库之前,您应该分析代码以查看程序花费大部分时间的函数并在优化其他函数之前优化这些函数。