Link 针对 NumPy 中的 BLAS 的 Cython 包装的 C 函数

Link Cython-wrapped C functions against BLAS from NumPy

我想在 Cython 扩展中使用一些在使用 BLAS 子例程的 .c 文件中定义的 C 函数,例如

cfile.c

double ddot(int *N, double *DX, int *INCX, double *DY, int *INCY);

double call_ddot(double* a, double* b, int n){
    int one = 1;
    return ddot(&n, a, &one, b, &one);
}

(假设这些函数不仅仅调用一个 BLAS 子例程)

pyfile.pyx

cimport numpy as np
import numpy as np

cdef extern from "cfile.c":
    double call_ddot(double* a, double* b, int n)

def pyfun(np.ndarray[double, ndim=1] a):
    return call_ddot(&a[0], &a[0], <int> a.shape[0])

setup.py:

from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize
from Cython.Distutils import build_ext
import numpy

setup(
    name  = "wrapped_cfun",
    packages = ["wrapped_cfun"],
    cmdclass = {'build_ext': build_ext},
    ext_modules = [Extension("wrapped_cfun.cython_part", sources=["pyfile.pyx"], include_dirs=[numpy.get_include()])]
)

我希望此包 link 针对已安装的 NumPy 或 SciPy 使用的相同 BLAS 库,并希望它可以在不同操作系统下使用 numpy 或 scipy 作为依赖项,没有任何额外的 BLAS 相关依赖项。

setup.py 是否有任何 hack 可以让我以一种可以与任何 BLAS 实现一起工作的方式来完成这个?

更新: 使用 MKL,我可以通过修改 Extension 对象使其指向 libmkl_rt 来使其工作,如果安装了 MKL,则可以从 numpy 中提取它,例如: Extension("wrapped_cfun.cython_part", sources=["pyfile.pyx"], include_dirs=[numpy.get_include()], extra_link_args=["-L{path to python's lib dir}", "-l:libmkl_rt.{so, dll, dylib}"]) 然而,同样的技巧对 OpenBLAS 不起作用(例如 -l:libopenblasp-r0.2.20.so)。如果该文件是 libopenblas 的 link,则指向 libblas.{so,dll,dylib} 将不起作用,但工作正常它是 link 到 libmkl_rt.

更新二: OpenBLAS 似乎在末尾用下划线命名他们的 C 函数,例如不是 ddot,而是 ddot_。如果我在 .c 文件中将 ddot 更改为 ddot_,则上面带有 l:libopenblas 的代码将起作用。我仍然想知道是否有某种(理想情况下 运行 时间)机制来检测应在 c 文件中使用哪个名称。

我终于想出了一个丑陋的 hack 来解决这个问题。我不确定它是否会一直有效,但至少它适用于 Windows(mingw 和 visual studio)、Linux、MKL 和 OpenBlas 的组合。我仍然想知道是否有更好的选择,但如果没有,就这样做:

编辑: 现在更正 visual studio

  1. 修改 C 文件以考虑带下划线的名称(对调用的每个 BLAS 函数执行此操作)- 需要对每个函数声明两次并为每个函数添加一个 if

    double ddot_(int *N, double *DX, int *INCX, double *DY, int *INCY); #define ddot(N, DX, INCX, DY, INCY) ddot_(N, DX, INCX, DY, INCY)

    daxpy_(int *N, double *DA, double *DX, int *INCX, double *DY, int *INCY); #define daxpy(N, DA, DX, INCX, DY, INCY) daxpy_(N, DA, DX, INCX, DY, INCY)

    ...等等

  2. 从 NumPy 或 SciPy 中提取库路径并将其添加到 link 参数中。

  3. 检测要使用的编译器是否为visual studio,在这种情况下linking参数完全不同。

setup.py

from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize
from Cython.Distutils import build_ext
import numpy
from sys import platform
import os

try:
    blas_path = numpy.distutils.system_info.get_info('blas')['library_dirs'][0]
except:
    if "library_dirs" in numpy.__config__.blas_mkl_info:
        blas_path = numpy.__config__.blas_mkl_info["library_dirs"][0]
    elif "library_dirs" in numpy.__config__.blas_opt_info:
        blas_path = numpy.__config__.blas_opt_info["library_dirs"][0]
    else:
        raise ValueError("Could not locate BLAS library.")
        

if platform[:3] == "win":
    if os.path.exists(os.path.join(blas_path, "mkl_rt.lib")):
        blas_file = "mkl_rt.lib"
    elif os.path.exists(os.path.join(blas_path, "mkl_rt.dll")):
        blas_file = "mkl_rt.dll"
    else:
        import re
        blas_file = [f for f in os.listdir(blas_path) if bool(re.search("blas", f))]
        if len(blas_file) == 0:
            raise ValueError("Could not locate BLAS library.")
        blas_file = blas_file[0]
        
elif platform[:3] == "dar":
    blas_file = "libblas.dylib"
else:
    blas_file = "libblas.so"

## 
class build_ext_subclass( build_ext ):
    def build_extensions(self):
        compiler = self.compiler.compiler_type
        if compiler == 'msvc': # visual studio
            for e in self.extensions:
                e.extra_link_args += [os.path.join(blas_path, blas_file)]
        else: # gcc
            for e in self.extensions:
                e.extra_link_args += ["-L"+blas_path, "-l:"+blas_file]
        build_ext.build_extensions(self)


setup(
    name  = "wrapped_cfun",
    packages = ["wrapped_cfun"],
    cmdclass = {'build_ext': build_ext_subclass},
    ext_modules = [Extension("wrapped_cfun.cython_part", sources=["pyfile.pyx"], include_dirs=[numpy.get_include()], extra_link_args=[])]
    )

依赖 linker/loader 提供正确的 blas 功能的替代方法是模拟必要 blas 符号的解析(例如 ddot)并使用包装的 blas-function provided by scipy 在运行期间。

不确定,这种方法优于 "normal way" 构建,但想引起您的注意,即使只是因为我觉得这种方法很有趣。

简而言之:

  1. 定义一个指向 ddot 功能的显式函数指针,在下面的代码片段中称为 my_ddot
  2. 在您将使用 ddot 的地方使用 my_ddot-指针,否则。
  3. 当 cython 模块加载 scipy 提供的功能时初始化 my_ddot-指针。

这是一个工作原型(我使用 C-code-verbatim 使代码段独立并且可以在 jupiter-notebook 中轻松测试,相信您可以将其转换为您的格式 need/like):

%%cython
# h-file:
cdef extern from *:
    """
    // blas-functionality,
    // will be initialized by cython when module is loaded:
    typedef double (*ddot_t)(int *N, double *DX, int *INCX, double *DY, int *INCY);
    extern ddot_t my_ddot;

    double call_ddot(double* a, double* b, int n);
    """
    ctypedef double (*ddot_t)(int *N, double *DX, int *INCX, double *DY, int *INCY)
    ddot_t my_ddot
    double call_ddot(double* a, double* b, int n)    

# init the functions of the c-library
# with blas-function provided by scipy
from scipy.linalg.cython_blas cimport ddot
my_ddot=ddot

# a simple function to demonstrate, that it works
def ddot_mult(double[:]a, double[:]b):
    cdef int n=len(a)
    return call_ddot(&a[0], &b[0], n)

#-------------------------------------------------
# c-file, added so the example is complete    
cdef extern from *:
    """  
    ddot_t my_ddot;
    double call_ddot(double* a, double* b, int n){
        int one = 1;
        return my_ddot(&n, a, &one, b, &one);
    }
    """
    pass

现在可以使用ddot_mult

import numpy as np
a=np.arange(4, dtype=float)

ddot_mult(a,a)  # 14.0 as expected!

这种方法的一个优点是,distutils 没有喧嚣,您可以保证使用与 scipy.

相同的 blas 功能

另一个好处:可以在运行时切换使用的引擎(mkl、open_blas 甚至是自己的实现),而无需 recompile/relink。

另一方面,还有一些额外的样板代码和危险,即某些符号的初始化将被遗忘。

作为较新的 Cython 版本的另一种选择,可以创建一个“public”Cython 函数(它将可用于 C 代码和 auto-generate public header) 将简单地调用相应的 BLAS 函数:

from scipy.linalg.cython_blas cimport ddot
cdef public double ddot_(int *n, double *x, int *ldx, double *y, int *ldy):
    return ddot(n, x, ldx, y, ldy)

然后只需在 C 代码中声明它或包含 header,Cython 扩展构建器的其余部分将负责链接:

extern double ddot_(int *n, double *x, int *ldx, double *y, int *ldy);