编写 cython 库时性能不佳

Poor performance in writing a cython library

test.pyx 是:

import gzip, re
import numpy as np
cimport numpy as np

cpdef np.ndarray[np.uint32_t, ndim=2] collect_qualities(str file_in, int length):
    '''
    What it Does:
    ------------
    bla bla bla

    Input:
    -----
    file_in: path-filename
    length: length of each sequence

    Output:
    ------
    numpy array with shape (n,m) where n=length of reads and m=solexa scores
    '''
 
    cdef str solexa_scores = '!"#$%&' + "'()*+,-./0123456789:;<=>?@ABCDEFGHI"

    cdef np.ndarray[np.uint32_t, ndim=1] N = np.zeros(shape=length, dtype=np.uint32) # This is the divisor of the mean
    cdef np.ndarray[np.uint32_t, ndim=2] sums = np.zeros(shape=(length, len(solexa_scores)+33), dtype=np.uint32) # This is the dividend of the mean
    
    cdef counter=0 # Useful to know if it's the 3rd or 4th line of the current sequence in fastq.
    with gzip.open(file_in, "rb") as f:
        for line in f:
        
            if counter%4==0: # first line of the sequence (obtain tail info)
                tile = line.decode('utf-8').split(':')[4]
                counter=0
        
            elif counter%3==0: # 3rd line of the sequence (obtain the qualities)
                for n, score in enumerate(line.decode('utf-8')):
                    sums[n, ord(score)] +=1
                    
            counter+=1
    return sums

setup.py 是:

from distutils.core import setup
from Cython.Build import cythonize
import numpy

setup(
        ext_modules = cythonize("test.pyx"),
        include_dirs=[numpy.get_include()]
)

它编译时带有关于已弃用的 numpy 的警告 API。

python.py 是:

import gzip, re
import numpy as np

def collect_qualities(file_in, length):
    '''
    What it Does:
    ------------
    bla bla bla

    Input:
    -----
    file_in: path-filename
    length: length of each sequence

    Output:
    ------
    numpy array with shape (n,m) where n=length of reads and m=solexa scores
    '''
 
    solexa_scores = '!"#$%&' + "'()*+,-./0123456789:;<=>?@ABCDEFGHI"
    
    N = np.zeros(shape=length, dtype=np.uint32) # This is the divisor of the mean
    sums = np.zeros(shape=(length, len(solexa_scores)+33)) # This is the dividend of the mean
    
    counter=0 # Useful to know if it's the 3rd or 4th line of the current sequence in fastq.
    with gzip.open(file_in, "rb") as f:
        for line in f:
        
            if counter%4==0: # first line of the sequence (obtain tail info)
                tile = line.decode('utf-8').split(':')[4]
                counter=0
        
            elif counter%3==0: # 3rd line of the sequence (obtain the qualities)
                for n, score in enumerate(line.decode('utf-8')):
                    sums[n, ord(score)] +=1
                    
            counter+=1
    return sums

然后我在 ipython 中导入函数并比较它们的 运行 次。

对于相当小的输入文件,python 需要大约 140 秒,而 cython 编译需要大约 950 秒。

我在 cython 中做错了什么?

谢谢!

正如我在评论中所说:我真的不相信你的时间安排,并认为这里可能还有其他事情发生(例如,你无意中使用了你安装在某个地方的旧版本,它做了一些不同的事情) .我希望 Cython 在这里稍微快一些,但不会显着。我也不可能用提供的信息实际测试我的任何建议。

但是,一些建议:

首先,你做的很多事情都是毫无意义的,或者是小小的悲观情绪。

  1. 使用 cpdef 函数没有意义,因为 defcpdef/cdef 函数的内部编译方式相同。使用 c[p]def 可以让您从 C/Cython 稍微更快地调用函数,这对于经常调用的小函数很有价值。我怀疑这是否适用于此。

  2. 指定 return 类型可能毫无意义。

  3. 打字 file_in 比不打字稍差 - Cython 无法优化它,因为它只将它传递给 open 函数,所以可能只是浪费时间type-check.

然后错过了一些优化机会:

  1. counter 应键入为 intcdef counter 只是使其成为通用 Python 对象)。

  2. 可能值得将 line 作为 str 输入(您需要在函数顶部而不是 line 时执行此操作用于 for-loop).

  3. 可能值得将中间 decoded_line 键入为 bytes(例如 decoded_line = line.decode(...)),因为这是要迭代的内容。 Cython 可能会从 str.decode 中自行推断出这一点,但最好确定一下。

  4. 在 Cython 中,在一个范围内直接迭代通常比使用 enumerate 这样的东西更好。 (这与 Python 不同)。做 for n in range(len(decoded_line)): score = decoded_line[n]。 Cython 可能会自己进行此优化,但请务必自己进行。

  5. 可能值得使用 compiler directives 关闭边界检查和环绕。我的建议是尽可能在本地执行此操作(即不要只是以 cargo-cult 方式将每个函数与它们一起包装,而是考虑它在哪里有用以及它是否安全)。

  6. sums 在您的 Python 和 Cython 版本中有不同的 dtypeuint32 在 Cython 中,double 在 Python).想想哪个是对的。

  7. 使用 cython -a 获取带注释的 html 版本的函数,突出显示未优化的位。担心突出显示的重要循环 - 不要过于关注只调用一次的东西。