Cython 中的复数

Complex numbers in Cython

在 Cython 中处理复数的正确方法是什么?

我想使用 dtype np.complex128 的 numpy.ndarray 编写纯 C 循环。在 Cython 中,关联的 C 类型定义在 Cython/Includes/numpy/__init__.pxd 作为

ctypedef double complex complex128_t

看来这只是一个简单的 C 双复数。

然而,很容易获得奇怪的行为。特别是,使用这些定义

cimport numpy as np
import numpy as np
np.import_array()

cdef extern from "complex.h":
    pass

cdef:
    np.complex128_t varc128 = 1j
    np.float64_t varf64 = 1.
    double complex vardc = 1j
    double vard = 1.

线

varc128 = varc128 * varf64

可以用 Cython 编译,但 gcc 不能编译生成的 C 代码(错误是 "testcplx.c:663:25: error: two or more data types in declaration specifiers",似乎是由于行 typedef npy_float64 _Complex __pyx_t_npy_float64_complex;)。这个错误已经被报告了(例如here)但是我没有找到任何好的解释and/or干净的解决方案。

如果不包含 complex.h,则不会出现错误(我猜是因为 typedef 不包含在内)。

但是,还有一个问题,因为在cython -a testcplx.pyx生成的html文件中,varc128 = varc128 * varf64行是黄色的,这意味着它还没有被翻译成纯C语言。对应的C代码为:

__pyx_t_2 = __Pyx_c_prod_npy_float64(__pyx_t_npy_float64_complex_from_parts(__Pyx_CREAL(__pyx_v_8testcplx_varc128), __Pyx_CIMAG(__pyx_v_8testcplx_varc128)), __pyx_t_npy_float64_complex_from_parts(__pyx_v_8testcplx_varf64, 0));
__pyx_v_8testcplx_varc128 = __pyx_t_double_complex_from_parts(__Pyx_CREAL(__pyx_t_2), __Pyx_CIMAG(__pyx_t_2));

并且 __Pyx_CREAL__Pyx_CIMAG 是橙色的(Python 调用)。

有趣的是,行

vardc = vardc * vard

不会产生任何错误并被翻译成纯 C(只是 __pyx_v_8testcplx_vardc = __Pyx_c_prod(__pyx_v_8testcplx_vardc, __pyx_t_double_complex_from_parts(__pyx_v_8testcplx_vard, 0));),而它与第一个非常相似。

我可以通过使用中间变量来避免错误(并且它转化为纯 C):

vardc = varc128
vard = varf64
varc128 = vardc * vard

或简单地通过强制转换(但不会转换为纯 C):

vardc = <double complex>varc128 * <double>varf64

那么会发生什么?编译错误是什么意思?有没有一种干净的方法可以避免它?为什么 np.complex128_t 和 np.float64_t 的乘法似乎涉及 Python 调用?

版本

Cython 版本 0.22(提出问题时 Pypi 中的最新版本)和 GCC 4.9.2。

存储库

我创建了一个带有示例 (hg clone https://bitbucket.org/paugier/test_cython_complex) 的小型存储库和一个带有 3 个目标 (make cleanmake buildmake html) 的小型 Makefile,因此很容易测试任何东西。

我能找到解决此问题的最简单方法是简单地切换乘法顺序。

如果在testcplx.pyx我改

varc128 = varc128 * varf64

varc128 = varf64 * varc128

我从描述的失败情况更改为正常工作的情况。这种情况很有用,因为它允许直接比较生成的 C 代码。

tl;博士

乘法的顺序改变了翻译,这意味着在失败版本中乘法是通过 __pyx_t_npy_float64_complex 类型尝试的,而在工作版本中它是通过 __pyx_t_double_complex 类型完成的。这又引入了无效的 typedef 行 typedef npy_float64 _Complex __pyx_t_npy_float64_complex;

我相当确定这是一个 cython 错误(更新:reported here). Although this is a very old gcc bug report,响应明确指出(说它实际上不是 gcc bug, 但用户代码错误):

typedef R _Complex C;

This is not valid code; you can't use _Complex together with a typedef, only together with "float", "double" or "long double" in one of the forms listed in C99.

他们得出结论,double _Complex 是一个有效的类型说明符,而 ArbitraryType _Complex 不是。 This more recent report has the same type of response - trying to use _Complex on a non fundamental type is outside spec, and the GCC manual表示_Complex只能与floatdoublelong double

一起使用

所以 - 我们可以破解 cython 生成的 C 代码来测试:用 typedef double _Complex __pyx_t_npy_float64_complex; 替换 typedef npy_float64 _Complex __pyx_t_npy_float64_complex; 并验证它确实有效并且可以使输出代码编译。


代码的短途跋涉

交换乘法顺序只会突出编译器告诉我们的问题。在第一种情况下,有问题的行是 typedef npy_float64 _Complex __pyx_t_npy_float64_complex; - 它试图分配类型 npy_float64 and 使用关键字 _Complex 到类型 __pyx_t_npy_float64_complex.

float _Complexdouble _Complex 是有效类型,而 npy_float64 _Complex 不是。要查看效果,您可以从该行删除 npy_float64,或将其替换为 doublefloat,代码编译正常。下一个问题是为什么首先要生产那条线...

这似乎是由 this line 在 Cython 源代码中产生的。

为什么乘法的顺序会显着改变代码 - 以至于引入类型 __pyx_t_npy_float64_complex,并且以失败的方式引入?

在失败的实例中,实现乘法的代码将 varf64 转换为 __pyx_t_npy_float64_complex 类型,对实部和虚部进行乘法,然后重新组合复数。在工作版本中,它使用函数 __Pyx_c_prod

通过 __pyx_t_double_complex 类型直接生成产品

我想这就像 cython 代码从遇到的第一个变量中获取用于乘法的提示一样简单。在第一种情况下,它看到一个浮点数 64,因此基于它生成 (invalid) C 代码,而在第二种情况下,它看到 (double) complex128 类型并将其转换基于那。这个解释有点曲折,如果时间允许,我希望return分析一下...

关于此的注释 - here we see that the typedef for npy_float64 is double, so in this particular case, a fix might consist of modifying the code here 使用 double _Complex 其中 typenpy_float64,但这超出了 SO 答案的范围并且没有提出一个通用的解决方案。


C 代码差异结果

工作版本

从行 `varc128 = varf64 * varc128

创建此 C 代码
__pyx_v_8testcplx_varc128 = __Pyx_c_prod(__pyx_t_double_complex_from_parts(__pyx_v_8testcplx_varf64, 0), __pyx_v_8testcplx_varc128);

版本失败

varc128 = varc128 * varf64

行创建此 C 代码
__pyx_t_2 = __Pyx_c_prod_npy_float64(__pyx_t_npy_float64_complex_from_parts(__Pyx_CREAL(__pyx_v_8testcplx_varc128), __Pyx_CIMAG(__pyx_v_8testcplx_varc128)), __pyx_t_npy_float64_complex_from_parts(__pyx_v_8testcplx_varf64, 0));
  __pyx_v_8testcplx_varc128 = __pyx_t_double_complex_from_parts(__Pyx_CREAL(__pyx_t_2), __Pyx_CIMAG(__pyx_t_2));

这需要这些额外的导入 - 而有问题的行是 typedef npy_float64 _Complex __pyx_t_npy_float64_complex; - 它试图分配类型 npy_float64 and类型_Complex到类型__pyx_t_npy_float64_complex

#if CYTHON_CCOMPLEX
  #ifdef __cplusplus
    typedef ::std::complex< npy_float64 > __pyx_t_npy_float64_complex;
  #else
    typedef npy_float64 _Complex __pyx_t_npy_float64_complex;
  #endif
#else
    typedef struct { npy_float64 real, imag; } __pyx_t_npy_float64_complex;
#endif

/*... loads of other stuff the same ... */

static CYTHON_INLINE __pyx_t_npy_float64_complex __pyx_t_npy_float64_complex_from_parts(npy_float64, npy_float64);

#if CYTHON_CCOMPLEX
    #define __Pyx_c_eq_npy_float64(a, b)   ((a)==(b))
    #define __Pyx_c_sum_npy_float64(a, b)  ((a)+(b))
    #define __Pyx_c_diff_npy_float64(a, b) ((a)-(b))
    #define __Pyx_c_prod_npy_float64(a, b) ((a)*(b))
    #define __Pyx_c_quot_npy_float64(a, b) ((a)/(b))
    #define __Pyx_c_neg_npy_float64(a)     (-(a))
  #ifdef __cplusplus
    #define __Pyx_c_is_zero_npy_float64(z) ((z)==(npy_float64)0)
    #define __Pyx_c_conj_npy_float64(z)    (::std::conj(z))
    #if 1
        #define __Pyx_c_abs_npy_float64(z)     (::std::abs(z))
        #define __Pyx_c_pow_npy_float64(a, b)  (::std::pow(a, b))
    #endif
  #else
    #define __Pyx_c_is_zero_npy_float64(z) ((z)==0)
    #define __Pyx_c_conj_npy_float64(z)    (conj_npy_float64(z))
    #if 1
        #define __Pyx_c_abs_npy_float64(z)     (cabs_npy_float64(z))
        #define __Pyx_c_pow_npy_float64(a, b)  (cpow_npy_float64(a, b))
    #endif
 #endif
#else
    static CYTHON_INLINE int __Pyx_c_eq_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_sum_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_diff_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_prod_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_quot_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_neg_npy_float64(__pyx_t_npy_float64_complex);
    static CYTHON_INLINE int __Pyx_c_is_zero_npy_float64(__pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_conj_npy_float64(__pyx_t_npy_float64_complex);
    #if 1
        static CYTHON_INLINE npy_float64 __Pyx_c_abs_npy_float64(__pyx_t_npy_float64_complex);
        static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_pow_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    #endif
#endif