加速cython代码

Speed up cython code

我写了一段处理大量数据的 Python 代码，运行非常耗时。于是我找到了 Cython，并开始修改我的代码。

基本上，我所做的就是修改函数声明（改成 `cdef 返回类型 函数名(类型 参数)` 的形式），用 cdef 声明变量及其类型，并把类声明为 cdef 类。我在 Eclipse 中编写所有 .pyx 文件，使用命令 `python setup.py build_ext --inplace` 进行编译，然后在 Eclipse 中运行编译后的程序。

我的问题是：把 Python 版本和 Cython 版本的速度进行对比，几乎没有任何区别。

我运行命令 `cython -a <file>` 生成了一个 HTML 文件，里面有很多黄色的行。

我不知道是自己哪里做错了，还是需要再添加些什么；我也不知道该如何消除这些黄色的行。

由于完整代码很长，下面只贴出了我想加速的那部分代码。


main.pyx

cdef ReadWavePoints (WavePointManagement wavePointManagement, ColumnManagement columnManagement):
    '''Process every chunk of wave points stored in wavePointsFile.

    wavePointsFile (a module-level name) holds a sequence of pickled objects
    (ndarray chunks). Each one is loaded, rounded, sorted and then grouped
    into voxels and columns by GroupColumnsVoxels, until loading or
    processing fails. The wave-point file and the column file are closed
    however the loop ends.
    '''
    wavePointManagement.OpenWavePointFileLoad(wavePointsFile)
    try:
        while True:
            try:
                wavePointManagement.LoadWavePointFile()
                wavePointManagement.RoundCoordinates()
                wavePointManagement.SortWavePointList()
                GroupColumnsVoxels(wavePointManagement.GetWavePointList(), columnManagement)
            except Exception:
                # End of input is detected through an exception: pickle.load
                # raises EOFError once the file is exhausted. With Cython < 3.0
                # the cdef void methods of WavePointManagement cannot propagate
                # that exception, so it may instead surface later (e.g. as
                # len(None) in GroupColumnsVoxels) -- hence the broad catch.
                # NOTE(review): this also hides genuine processing errors;
                # narrow it to EOFError once those methods propagate exceptions.
                break
    finally:
        # Always release both files, including on KeyboardInterrupt.
        wavePointManagement.CloseWavePointFile()
        columnManagement.CloseWriteColumnFile()

cdef _RegisterColumn(object point, ColumnManagement columnManagement):
    '''Add the column at point's (X, Y) and pass point's Z to MaximumHeightColumn.'''
    columnObject = columnInstance.Column(point.GetX(), point.GetY())
    columnManagement.AddColumn(columnObject)
    MaximumHeightColumn(point.GetX(), point.GetY(), point.GetZ())


cdef GroupColumnsVoxels (object wavePointList, ColumnManagement columnManagement):
    '''Group adjacent points into voxels (same X, Y, Z) and columns (same X, Y).

    wavePointList is expected to be ordered so that points sharing (X, Y), and
    within that Z, are adjacent (SortWavePointList sorts by x, y, z).

    For every voxel, the largest GetValue() of its points is passed to
    CheckVoxel. For every column, a Column object is added to
    columnManagement, and the Z of the column's last voxel (its highest Z,
    presumably, given the ascending sort) is passed to MaximumHeightColumn.

    Every voxel is flushed exactly once. This includes a lone final point and
    a final point that starts a new column. An empty list does nothing.
    '''
    cdef Py_ssize_t indexWavePoint, sizeWavePointList
    cdef double voxelValue
    cdef object refPoint, wavePoint

    sizeWavePointList = len(wavePointList)
    if sizeWavePointList == 0:
        # Nothing to group; indexing [0] below would raise IndexError.
        return

    # refPoint is the first point of the voxel currently being accumulated,
    # and voxelValue is the running maximum value of that voxel.
    refPoint = wavePointList[0]
    voxelValue = refPoint.GetValue()

    for indexWavePoint in range(1, sizeWavePointList):
        wavePoint = wavePointList[indexWavePoint]
        if refPoint.GetX() == wavePoint.GetX() and refPoint.GetY() == wavePoint.GetY():
            if refPoint.GetZ() == wavePoint.GetZ():
                # Same voxel: only the maximum value matters.
                if voxelValue < wavePoint.GetValue():
                    voxelValue = wavePoint.GetValue()
                continue
            # Same column, new Z: the previous voxel is complete.
            CheckVoxel(refPoint.GetX(), refPoint.GetY(), refPoint.GetZ(), voxelValue)
        else:
            # New (X, Y): both the previous voxel and its column are complete.
            CheckVoxel(refPoint.GetX(), refPoint.GetY(), refPoint.GetZ(), voxelValue)
            _RegisterColumn(refPoint, columnManagement)
        refPoint = wavePoint
        voxelValue = wavePoint.GetValue()

    # Flush the voxel and column still open when the list runs out.
    CheckVoxel(refPoint.GetX(), refPoint.GetY(), refPoint.GetZ(), voxelValue)
    _RegisterColumn(refPoint, columnManagement)



cdef CheckVoxel (double X, double Y, double Z, double newValue):
    '''Store newValue in the voxel containing (X, Y, Z) if it beats the stored value.

    Reads the single Float32 cell of band int(floor(Z / 0.3)) + 1 of the
    module-level datasetVoxels raster, at the pixel derived from (X, Y) with
    Xmin, Ymax and a 0.25 pixel size. The cell is overwritten only when
    newValue is strictly greater than the stored value.
    '''
    cdef object bandVoxel, structvalCheckVoxel, out_str
    cdef tuple valueCheckVoxel
    cdef int xOffset, yOffset

    bandVoxel = datasetVoxels.GetRasterBand(int(math.floor(Z/0.3))+1)
    # Pixel offsets are computed once and shared by the read and the write.
    xOffset = int(math.floor((X - Xmin) / 0.25))
    yOffset = int(math.floor((Ymax - Y) / 0.25))

    structvalCheckVoxel = bandVoxel.ReadRaster(xOffset, yOffset, 1, 1, buf_type=gdal.GDT_Float32)
    valueCheckVoxel = struct.unpack('f', structvalCheckVoxel)

    if newValue > valueCheckVoxel[0]:
        out_str = struct.pack('f', newValue)
        bandVoxel.WriteRaster(xOffset, yOffset, 1, 1, out_str)

cdef MaximumHeightColumn(double X, double Y, double newZ):
    '''Record newZ as the height of the column at (X, Y) if it is higher.

    Reads the single Float32 cell of band 10 of the module-level
    datasetMetrics raster, at the pixel derived from (X, Y) with Xmin, Ymax
    and a 0.25 pixel size. The cell is overwritten with newZ only when newZ
    is strictly greater than the stored value rounded to one decimal.
    '''
    cdef object bandMetricMaximumHeightColumn, structvalMaximumHeightColumn, out_strMaximumHeightColumn
    cdef tuple valueMaximumHeightColumn
    cdef int xOffset, yOffset

    bandMetricMaximumHeightColumn = datasetMetrics.GetRasterBand(10)
    # Pixel offsets are computed once and shared by the read and the write.
    xOffset = int(math.floor((X - Xmin) / 0.25))
    yOffset = int(math.floor((Ymax - Y) / 0.25))

    structvalMaximumHeightColumn = bandMetricMaximumHeightColumn.ReadRaster(xOffset, yOffset, 1, 1, buf_type=gdal.GDT_Float32)
    valueMaximumHeightColumn = struct.unpack('f', structvalMaximumHeightColumn)

    # The stored Float32 is rounded to 1 decimal before comparing, presumably
    # so that a float32 copy of an equal Z does not compare as smaller.
    if newZ > round(valueMaximumHeightColumn[0], 1):
        out_strMaximumHeightColumn = struct.pack('f', newZ)
        bandMetricMaximumHeightColumn.WriteRaster(xOffset, yOffset, 1, 1, out_strMaximumHeightColumn)

WavePointManagement.pyx

'''this class serializes, rounds and sorts the points of each ndarray'''
import math
from operator import attrgetter
import cPickle as pickle

import numpy as np
cimport numpy as np

cdef class WavePointManagement(object):
    '''
    Load, round and sort the wave points extracted from the waveform.

    Points are read chunk by chunk from a pickle file. Every
    LoadWavePointFile() call replaces wavePointList with the next object
    pickled in the file.

    NOTE(review): with Cython < 3.0 a ``cdef void`` method cannot report a
    Python exception to its caller; the exception is printed and ignored.
    Declaring these methods ``except *`` would let e.g. the EOFError at end of
    file reach the caller. That change must also be made in the matching .pxd
    declarations (presumably WavePointManagement.pxd).
    '''
    cdef object fileObject, wavePointList
    # NOTE(review): a cdef class's attributes are already fixed by the cdef
    # declaration above, so this tuple is presumably redundant -- confirm.
    __slots__ = ('wavePointList', 'fileObject')

    def __cinit__(self):
        '''Start with no open file and an empty point list.'''
        self.fileObject = None
        # Replaced by whatever pickle.load() returns, and by a Python list
        # after SortWavePointList().
        self.wavePointList = np.array([])

    cdef object GetWavePointList(self):
        '''Return the current chunk of points.'''
        return self.wavePointList

    cdef void OpenWavePointFileLoad (self, object fileName):
        '''Open fileName for binary reading.'''
        # open() returns the same file object as the Python 2 file() builtin
        # and also exists on Python 3.
        self.fileObject = open(fileName, 'rb')

    cdef void LoadWavePointFile (self):
        '''Replace the current chunk with the next object pickled in the file.

        pickle.load raises EOFError once the file is exhausted.
        '''
        # Drop the previous chunk first so it can be freed before the next
        # one is built.
        self.wavePointList = None
        # SECURITY: unpickling can execute arbitrary code; only load trusted files.
        self.wavePointList = pickle.load(self.fileObject)

    cdef void SortWavePointList (self):
        '''Sort the current chunk by (x, y, z); the result is a Python list.'''
        # attrgetter('x', 'y', 'z') builds the same (p.x, p.y, p.z) key tuple
        # as a lambda would, but without a Python-level call per element.
        # NOTE(review): the key reads .x/.y/.z attributes while
        # RoundCoordinates uses GetX()/SetX(); confirm they are the same values.
        self.wavePointList = sorted(self.wavePointList, key=attrgetter('x', 'y', 'z'))

    cdef void RoundCoordinates (self):
        '''Snap every point of the current chunk to the voxel grid.

        X is floored and Y is ceiled to a multiple of 0.25, both rounded to 2
        decimals. Z is floored to a multiple of 0.3 and rounded to 1 decimal.
        '''
        for pointObject in self.wavePointList:
            pointObject.SetX(round(math.floor(pointObject.GetX() / 0.25) * 0.25, 2))
            pointObject.SetY(round(math.ceil(pointObject.GetY() / 0.25) * 0.25, 2))
            pointObject.SetZ(round(math.floor(pointObject.GetZ() / 0.3) * 0.3, 1))

    cdef void CloseWavePointFile(self):
        '''Close the wave-point file if one is open; safe to call repeatedly.'''
        if self.fileObject is not None:
            self.fileObject.close()
            self.fileObject = None

setup.py

from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

import numpy

# Build main.pyx as the "main" extension module. NumPy's header directory is
# put on the include path, as required by any `cimport numpy` in the sources.
# NOTE(review): only main.pyx is listed here; WavePointManagement.pyx must be
# compiled somewhere as well -- confirm.
ext = Extension(
    "main",
    sources=["main.pyx"],
    include_dirs=[numpy.get_include()],
)

setup(
    cmdclass={"build_ext": build_ext},
    ext_modules=[ext],
)

test_cython.py

'''Entry script run from Eclipse after compiling the extension modules.'''
# NOTE: main must be a def or cpdef function in main.pyx; cdef functions are
# not importable from Python.
from main import main


if __name__ == '__main__':
    # Guarded so that merely importing this module does not start a full run.
    main()

我怎样才能加快这段代码的速度?

您的代码在 numpy 数组和 Python 列表之间来回切换，因此 Cython 生成的代码与纯 Python 相比几乎没有提速。

例如，下面这行代码会生成一个 Python 列表，而用作排序键的 key 函数也是一个纯 Python 函数。

# sorted() always returns a Python list (not an ndarray), and the key lambda
# is a Python-level call made once per element.
self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))

您需要使用 ndarray.sort（如果不想就地排序，可以使用 numpy.sort）。为此，您还需要改变对象在数组中的存储方式，也就是改用结构化数组（structured array）。关于如何对结构化数组排序，请参阅 numpy.sort 文档中的示例——尤其是页面上的最后两个示例。

将数据存入 numpy 数组后，您还需要告诉 Cython 数据在数组中是如何存储的，包括元素的类型信息和数组的维数。原文所链接的页面提供了有关如何高效使用 numpy 数组的更多信息。

创建和排序结构化数组的示例:

import numpy as np
cimport numpy as np

# NumPy record dtype and the C struct describing the same record: field order,
# types and sizes must match so buffer elements can be read as Person structs.
DTYPE = [('name', 'S10'), ('height', np.float64), ('age', np.int32)]

# `packed` removes C padding between fields, matching NumPy's default
# (unaligned) layout for a dtype built from a field list.
cdef packed struct Person:
    char name[10]
    np.float64_t height
    np.int32_t age

ctypedef Person DTYPE_t

def create_array():
    '''Return a small structured array of (name, height, age) records.'''
    knights = [
        ('Arthur', 1.8, 41),
        ('Lancelot', 1.9, 38),
        ('Galahad', 1.7, 38),
    ]
    return np.array(knights, dtype=DTYPE)

cpdef sort_by_age_then_height(np.ndarray[DTYPE_t, ndim=1] arr):
    '''Sort arr in place by age, breaking ties by height; returns None.'''
    sort_fields = ['age', 'height']
    arr.sort(order=sort_fields)

此外，为了进一步提速，您需要把代码中调用的 Python 函数换成 C 标准库函数。下面以 RoundCoordinates 为例。`cpdef` 表示该函数还会通过一个包装函数暴露给 Python。

cimport cython
cimport numpy as np
from libc.math cimport floor, ceil, round

import numpy as np

# NumPy record dtype for one point; must match Point3D below field for field.
DTYPE = [('x', np.float64), ('y', np.float64), ('z', np.float64)]

# C view of one record of an array with the dtype above.
cdef packed struct Point3D:
    np.float64_t x, y, z

ctypedef Point3D DTYPE_t

# Caution should be used when turning the bounds check off as it can lead to undefined
# behaviour if you use an invalid index.
@cython.boundscheck(False)
cpdef RoundCoordinates_cy(np.ndarray[DTYPE_t] pointlist):
    '''Snap every point of pointlist to the voxel grid, in place.

    Mirrors WavePointManagement.RoundCoordinates. x is floored and y ceiled to
    a multiple of 0.25, both rounded to 2 decimals. z is floored to a multiple
    of 0.3 and rounded to 1 decimal.
    '''
    cdef Py_ssize_t i
    cdef DTYPE_t point
    for i in range(len(pointlist)):  # typed index: compiled into a plain C loop
        point = pointlist[i]  # copies the record into a C struct
        # round(v * 100) / 100 rounds v to 2 decimals. With v = k * 0.25 that
        # is round(k * 25) / 100 (not k * 2.5 / 10, which keeps only 1 decimal).
        point.x = round(floor(point.x / 0.25) * 25) / 100
        point.y = round(ceil(point.y / 0.25) * 25) / 100
        # round(v * 10) / 10 rounds v to 1 decimal; with v = k * 0.3 that is
        # round(k * 3) / 10.
        point.z = round(floor(point.z / 0.3) * 3) / 10
        pointlist[i] = point  # write the rounded record back into the array

最后，在重写整个代码库之前，您应该先对代码做性能分析（profiling），找出程序耗时最多的函数，并优先优化这些函数。