避免 Numpy 中的 "complex division by zero" 错误(与 cmath 相关)

Avoiding "complex division by zero" error in Numpy (related to cmath)

我想在我的计算中避免 ZeroDivisionError: complex division by zero,在该异常处得到 nan

让我用一个简单的例子来说明我的问题:

from numpy import *
from cmath import *

def f(x) :
    x = float_(x) 
    return 1./(1.-x)

def g(x) :
    x = float_(x)
    return sqrt(-1.)/(1.-x)

f(1.)   # This gives 'inf'.
g(1.)   # This gives 'ZeroDivisionError: complex division by zero'.

我打算得到 g(1.) = nan,或者至少是除了中断计算的错误之外的任何东西。 第一个问题:我该怎么做?

重要的是,我不想修改函数内部的代码(例如,插入异常条件,如中所做的那样),而是保持当前形式(或者如果可能,甚至删除 x = float_(x) 行,正如我在下面提到的)。 原因是我正在使用包含许多函数的长代码:我希望它们都避免 ZeroDivisionError 而无需进行大量更改。

我被迫插入 x = float_(x) 以避免 f(1.) 中出现 ZeroDivisionError第二个问题: 是否有一种方法可以抑制这条线,但仍然得到 f(1.) = inf 而无需修改定义 f 的所有代码?


编辑:

我意识到使用 cmath (from cmath import *) 是造成错误的原因。 没有它,我得到 g(1.) = nan,这就是我想要的。 但是,我的代码中需要它。 所以现在第一个问题变成了下面的问题:如何在使用cmath时避免"complex division by zero"?


编辑 2:

看完答案后,我做了一些修改,我简化了问题,更接近重点:

import numpy as np                            
import cmath as cm                            

def g(x) :                                    
    x = np.float_(x)                         
    return cm.sqrt(x+1.)/(x-1.)    # I want 'g' to be defined in R-{1}, 
                                   # so I have to use 'cm.sqrt'. 

print 'g(1.) =', g(1.)             # This gives 'ZeroDivisionError: 
                                   # complex division by zero'.

问题: 如何避免 ZeroDivisionError 不修改我的函数 g 的代码?

我希望该函数自己处理错误,但如果没有,那么您可以在调用该函数时执行此操作(不过我相信这比修改函数更有效)。

try:
    f(1.)   # This gives 'inf'.
except ZeroDifivisionError:
    None    #Or whatever
try:
    g(1.)   # This gives 'ZeroDivisionError: complex division by zero'.
except ZeroDifivisionError:
    None    #Or whatever

或者,正如您提到的答案所说,用一个 numpy 数字调用您的函数:

f(np.float_(1.))
g(np.float_(1.))

=======
在我的机器上,这个确切的代码给我一个警告并且没有错误。 我不认为在不捕获错误的情况下得到你想要的东西是不可能的...

def f(x):
    return 1./(1.-x)

def g(x):
    return np.sqrt(-1.)/(1.-x)

print f(np.float_(1.))
>> __main__:2: RuntimeWarning: divide by zero encountered in double_scalars
>>inf

print g(np.float_(1.))
>> __main__:5: RuntimeWarning: invalid value encountered in sqrt
>>nan

出现错误是因为您的程序正在使用 cmath 库中的 sqrt 函数。这是一个很好的例子,说明为什么您应该避免(或至少要小心)使用 from library import *.

导入整个库

例如,如果您反转导入语句,那么您将没有问题:

from cmath import *
from numpy import *

def f(x):
    x = float_(x) 
    return 1./(1.-x)


def g(x) :
    x = float_(x)
    return sqrt(-1.)/(1.-x)

现在g(1.)returnsnan,因为用到了numpysqrt函数(可以处理负数)。但这仍然是不好的做法:如果你有一个大文件,那么不清楚正在使用哪个 sqrt

我建议始终使用命名导入,例如 import numpy as npimport cmath as cm。那么函数sqrt没有被导入,要定义g(x)需要写成np.sqrt(或者cm.sqrt).

但是,您注意到您不想更改您的函数。在那种情况下,你应该只从你需要的每个库中导入那些函数,即

from cmath import sin  # plus whatever functions you are using in that file
from numpy import sqrt, float_

def f(x):
    x = float_(x) 
    return 1./(1.-x)


def g(x) :
    x = float_(x)
    return sqrt(-1.)/(1.-x)

不幸的是,您不能轻易摆脱到 numpy 的转换 float_,因为 python 浮点数与 numpy 浮点数不同,所以这种转换需要在某个时候进行。

我还是不明白你为什么要用cmath。 Type-cast x 变成 np.complex_ 当你期望复杂输出时,然后使用 np.sqrt.

import numpy as np

def f(x):
    x = np.float_(x)
    return 1. / (1. - x)

def g(x):
    x = np.complex_(x)
    return np.sqrt(x + 1.) / (x - 1.)

这产生:

>>> f(1.)
/usr/local/bin/ipython3:3: RuntimeWarning: divide by zero encountered in double_scalars
  # -*- coding: utf-8 -*-
Out[131]: inf

>>> g(-3.)
Out[132]: -0.35355339059327379j

>>> g(1.)
/usr/local/bin/ipython3:3: RuntimeWarning: divide by zero encountered in cdouble_scalars
  # -*- coding: utf-8 -*-
/usr/local/bin/ipython3:3: RuntimeWarning: invalid value encountered in cdouble_scalars
  # -*- coding: utf-8 -*-
Out[133]: (inf+nan*j)

缺点是,当然,函数 g 总是 最终给你复杂的输出,如果你反馈它的结果,这可能会在以后引起问题例如,进入 f,因为 that 现在 type-casts 浮动,等等......也许你应该只是 type-cast 到复杂的地方。但这将取决于您需要在更大范围内实现什么。

编辑

事实证明,只有在需要时,才有办法使 g 到 return 复杂。使用 numpy.emath.

import numpy as np

def f(x):
    x = np.float_(x)
    return 1. / (1. - x)

def g(x):
    x = np.float_(x)
    return np.emath.sqrt(x + 1.) / (x - 1.)

这现在给出了您所期望的,仅在必要时才转换为复数。

>>> f(1)
/usr/local/bin/ipython3:8: RuntimeWarning: divide by zero encountered in double_scalars
Out[1]: inf

>>> g(1)
/usr/local/bin/ipython3:12: RuntimeWarning: divide by zero encountered in double_scalars
Out[2]: inf

>>> g(-3)
Out[3]: -0.35355339059327379j