Numpy 分段意外结果

Numpy piecewise unexpected results

我正在尝试使用 numpy.piecewise:

定义一个非连续函数
import numpy as np

var = np.array([0.2, 10])
supp = lambda x: -np.sqrt(1 - 1/x**2)
inf = lambda x: 1j*np.sqrt(1/x**2 - 1)

def q1(eta_B):
    return np.piecewise(eta_B, [eta_B > 1, eta_B <= 1],
                        [supp, inf])

使用这段代码,这是我得到的:

>>> supp(var)
array([        nan, -0.99498744])
>>> inf(var)
array([ 0.+4.89897949j, nan       +nanj])
>>> q1(var)
array([ 0.        , -0.99498744])

我不明白这里发生了什么;我预计 q1(var) 到 return array([0.+4.89897949j , -0.99498744]) !

有什么提示吗?

当你 运行 q1(var) 你没有收到以下警告吗?

Warning (from warnings module):
  File "C:\Users\Chinchi\AppData\Roaming\Python\Python38\site-packages\numpy\lib\function_base.py", line 616
    y[cond] = func(vals, *args, **kw)
ComplexWarning: Casting complex values to real discards the imaginary part

查看 np.piecewise 的文档,我们看到

The output is the same shape and type as x

因为您从未为 var 定义类型,并且它的所有值都在 float 范围内,所以这就是它获得的类型。

>>> var.dtype
dtype('float64')
>>> var = np.array([0.2+0j, 10])
>>> var.dtype
dtype('complex128')
>>> q1(var)
array([ 0.        +4.89897949j, -0.99498744-0.j        ])

解决此问题的最佳方法是在函数内将输入转换为复杂类型,这样无论您是否通过 int/float/complex 都可以工作,而不必自己担心转换。

np.piecewise(eta_B.astype(np.complex128), [eta_B > 1, eta_B <= 1], [supp, inf])