为什么 dgemm(Cython 编译)比 numpy.dot 慢

Why dgemm (Cython compiled) is slower than numpy.dot

长话短说,我在 Cython 中构建了一个简单的乘法函数,调用 scipy.linalg.cython_blas.dgemm,编译它并 运行 它针对基准 Numpy.dot。当我使用静态定义、数组维度预分配、内存视图、关闭检查等技巧时,我听说过 50% 到 100 倍的性能提升的神话。但后来我写了自己的 my_dot 函数(编译后),它比默认值 Numpy.dot4 倍。我真的不知道是什么原因,所以我只能有一些猜测:

1) BLAS 库未链接

2) 可能有一些我没有发现的内存开销

3) dot 正在使用一些隐藏的魔法

4) 写得不好 setup.py 并且 c 代码没有优化编译

5) 我的my_dot函数写得不高效

以下是我的代码片段和我能想到的所有相关信息,可能有助于解决这个难题。如果有人能提供一些关于我做错了什么的见解,或者如何将性能提高到至少与默认值相当 Numpy.dot

,我将不胜感激

文件 1:model_cython/multi.pyx。您还需要在文件夹中添加 model_cython/init.py

#cython: language_level=3 
#cython: boundscheck=False
#cython: nonecheck=False
#cython: wraparound=False
#cython: infertypes=True
#cython: initializedcheck=False
#cython: cdivision=True
#distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration


from scipy.linalg.cython_blas cimport dgemm
import numpy as np
from numpy cimport ndarray, float64_t
from numpy cimport PyArray_ZEROS
cimport numpy as np
cimport cython

np.import_array()
ctypedef float64_t DOUBLE

def my_dot(double [::1, :] a, double [::1, :] b, int ashape0, int ashape1, 
        int bshape0, int bshape1):
    cdef np.npy_intp cshape[2]
    cshape[0] = <np.npy_intp> ashape0
    cshape[1] = <np.npy_intp> bshape1

    cdef:
        int FORTRAN = 1
        ndarray[DOUBLE, ndim=2] c = PyArray_ZEROS(2, cshape, np.NPY_DOUBLE, FORTRAN)

    cdef double alpha = 1.0
    cdef double beta = 0.0
    dgemm("N", "N", &ashape0, &bshape1, &ashape1, &alpha, &a[0,0], &ashape0, &b[0,0], &bshape0, &beta, &c[0,0], &ashape0)
    return c

文件 2:model_cython/example.py。做基准测试的脚本

setup_str = """
import numpy as np
from numpy import float64
from multi import my_dot

a = np.ones((2,3), dtype=float64, order='F')
b = np.ones((3,2), dtype=float64, order='F')
print(a.flags)
ashape0, ashape1 = a.shape
bshape0, bshape1 = b.shape
"""
import timeit
print(timeit.timeit(stmt='c=my_dot(a,b, ashape0, ashape1, bshape0, bshape1)', setup=setup_str, number=100000))
print(timeit.timeit(stmt='c=a.dot(b)', setup=setup_str, number=100000))

文件 3:setup.py。编译 .so 文件

from distutils.core import setup, Extension
from Cython.Build import cythonize
from Cython.Distutils import build_ext
import numpy 
import os
basepath = os.path.dirname(os.path.realpath(__file__))
numpy_path = numpy.get_include()
package_name = 'multi'
setup(
        name='multi',
        cmdclass={'build_ext': build_ext},
        ext_modules=[Extension(package_name, 
            [os.path.join(basepath, 'model_cython', 'multi.pyx')], 
            include_dirs=[numpy_path],
            )],
        )

文件 4:run.sh。 Shell 执行 setup.py 并移动东西的脚本

python3 setup.py build_ext --inplace
path=$(pwd)
rm -r build
mv $path/multi.cpython-37m-darwin.so $path/model_cython/
rm $path/model_cython/multi.c

下面是编译消息的截图:

关于BLAS,我的Numpy/usr/local/lib处​​正确链接到它,clang -bundle似乎也在编译时添加了-L/usr/local/lib。但也许这还不够?

Cython 擅长优化循环(在 Python 中通常很慢),也是一种调用 C 的便捷方式(这是你想做的)。但是,从 Python 调用 Cython 函数可能相对较慢 - 特别是因为您指定的所有类型都需要检查一致性。因此,您通常会尝试在一次 Cython 调用后隐藏大量工作,以便开销很小。

您几乎选择了最坏的情况:大量呼叫背后的一小部分工作。 Cython 或 np.dot 是否会有更多开销是相当随意的,但无论哪种方式,你正在测量,而不是 np.dot 与 BLAS dgemm.

从您的评论看来,您实际上想要对两个 3D 数组的前两个维度进行点积。因此,一个更有用的测试是尝试重现它。这里有三个版本:

def einsum_mult(a,b):
    # use np.einsum, won't benefit from Cython
    return np.einsum("ijh,jkh->ikh",a,b)

def manual_mult(a,b):
    # multiply one matrix at a time with numpy dot
    # (could probably be optimized a bit with Cython)
    c = np.empty((a.shape[0],b.shape[1],a.shape[2]),
                 dtype=np.float64, order='F')
    for n in range(a.shape[2]):
        c[:,:,n] = a[:,:,n].dot(b[:,:,n])
    return c

def blas_version(double[::1,:,:] a,double[::1,:,:] b):
    # uses dgemm
    cdef double[::1,:,:] c = np.empty((a.shape[0], b.shape[1], a.shape[2]),
                                      dtype=np.float64, order='F')
    cdef double[::1,:] c_part
    cdef int n
    cdef double alpha = 1.0
    cdef double beta = 0.0
    cdef int ashape0 = a.shape[0], ashape1 = a.shape[1], bshape0 = b.shape[0], bshape1 = b.shape[1]

    assert a.shape[2]==b.shape[2]
    assert a.shape[1]==b.shape[0]

    for n in range(a.shape[2]):
        c_part = c[:,:,n]
        dgemm("N", "N", &ashape0, &bshape1, &ashape1, &alpha, &a[0,0,n], &ashape0, 
              &b[0,0,n], &bshape0, &beta, &c_part[0,0], &ashape0)
    return c

对于大小为 (2,3,10000)(3,2,10000) 的数组,重复 100 次,我得到:

manual_mult 1.6531286190001993 s    (i.e. quite bad)
einsum 0.3215398370011826 s         (pretty good)
blas_version 0.15762194800481666 s  (best, pretty close to the "myth" performance gain you mention)

如果您利用 Cython 并将循环保留在编译代码中,则 BLAS 版本速度很快。 (我没有花任何精力来优化它,所以如果你尝试的话你可能会打败它,但这只是为了说明这一点)