关于python:在Numpy中避免“复杂除零”错误(与cmath相关)

Avoiding “complex division by zero” error in Numpy (related to cmath)

我想在我的计算中避免ZeroDivisionError: complex division by zero,在该异常处得到nan

让我用一个简单的例子来陈述我的问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
from numpy import *
from cmath import *

def f(x) :
    x = float_(x)
    return 1./(1.-x)

def g(x) :
    x = float_(x)
    return sqrt(-1.)/(1.-x)

f(1.)   # This gives 'inf'.
g(1.)   # This gives 'ZeroDivisionError: complex division by zero'.

我打算得到g(1.) = nan,或者至少是一个中断计算的错误。
第一个问题:我该怎么办?

重要的是,我不想修改函数内部的代码(例如,插入异常的条件,如本答案中所做的那样),而是保持其当前形式(如果可能,甚至删除x = float_(x)行,正如我在下面提到的)。
原因是我正在使用包含许多函数的长代码:我希望它们都能避免ZeroDivisionError,而无需进行大量更改。

我被迫插入x = float_(x)以避免f(1.)中的ZeroDivisionError
第二个问题:是否有一种方法可以抑制这一行,但仍然得到f(1.) = inf而不修改定义f的所有代码?

编辑:

我已经意识到使用cmath(from cmath import *)导致错误。
没有它,我得到g(1.) = nan,这就是我想要的。
但是,我需要在我的代码中使用它。
所以现在第一个问题变成了以下问题:当使用cmath时,如何避免"复杂除零"?

编辑2:

在阅读答案后,我做了一些修改,并简化了问题,越来越接近这一点:

1
2
3
4
5
6
7
8
9
10
import numpy as np                            
import cmath as cm                            

def g(x) :                                    
    x = np.float_(x)                        
    return cm.sqrt(x+1.)/(x-1.)    # I want 'g' to be defined in R-{1},
                                   # so I have to use 'cm.sqrt'.

print 'g(1.) =', g(1.)             # This gives 'ZeroDivisionError:
                                   # complex division by zero'.

问题:如何避免ZeroDivisionError不修改我的函数g的代码?


我仍然不明白为什么你需要使用cmath。当您期望复杂输出时,将x输入x,然后使用np.sqrt

1
2
3
4
5
6
7
8
9
import numpy as np

def f(x):
    x = np.float_(x)
    return 1. / (1. - x)

def g(x):
    x = np.complex_(x)
    return np.sqrt(x + 1.) / (x - 1.)

这会产生:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> f(1.)
/usr/local/bin/ipython3:3: RuntimeWarning: divide by zero encountered in double_scalars
  # -*- coding: utf-8 -*-
Out[131]: inf

>>> g(-3.)
Out[132]: -0.35355339059327379j

>>> g(1.)
/usr/local/bin/ipython3:3: RuntimeWarning: divide by zero encountered in cdouble_scalars
  # -*- coding: utf-8 -*-
/usr/local/bin/ipython3:3: RuntimeWarning: invalid value encountered in cdouble_scalars
  # -*- coding: utf-8 -*-
Out[133]: (inf+nan*j)

当然,缺点是函数g总是会给你带来复杂的输出,如果你把它的结果反馈回f,这可能会导致问题,因为现在这个类型转换为float,等等......也许你应该只是在任何地方打字复杂。但这取决于你需要在更大规模上实现的目标。

编辑

事实证明,只有在需要时才能让g返回复合体。使用numpy.emath

1
2
3
4
5
6
7
8
9
import numpy as np

def f(x):
    x = np.float_(x)
    return 1. / (1. - x)

def g(x):
    x = np.float_(x)
    return np.emath.sqrt(x + 1.) / (x - 1.)

现在,这将提供您所期望的,仅在必要时转换为复合体。

1
2
3
4
5
6
7
8
9
10
>>> f(1)
/usr/local/bin/ipython3:8: RuntimeWarning: divide by zero encountered in double_scalars
Out[1]: inf

>>> g(1)
/usr/local/bin/ipython3:12: RuntimeWarning: divide by zero encountered in double_scalars
Out[2]: inf

>>> g(-3)
Out[3]: -0.35355339059327379j

我希望函数能够自己处理错误,但如果没有,那么你可以在调用函数时执行它(我确实相信这比修改函数更多的工作)。

1
2
3
4
5
6
7
8
try:
    f(1.)   # This gives 'inf'.
except ZeroDifivisionError:
    None    #Or whatever
try:
    g(1.)   # This gives 'ZeroDivisionError: complex division by zero'.
except ZeroDifivisionError:
    None    #Or whatever

或者,正如您所提到的答案所述,请使用numpy数字调用您的函数:

1
2
f(np.float_(1.))
g(np.float_(1.))

=======
在我的机器上,这个确切的代码给了我一个警告,没有错误。
我不认为有可能得到你想要的东西而不会发现错误......

1
2
3
4
5
6
7
8
9
10
11
12
13
def f(x):
    return 1./(1.-x)

def g(x):
    return np.sqrt(-1.)/(1.-x)

print f(np.float_(1.))
>> __main__:2: RuntimeWarning: divide by zero encountered in double_scalars
>>inf

print g(np.float_(1.))
>> __main__:5: RuntimeWarning: invalid value encountered in sqrt
>>nan


出现错误是因为您的程序正在使用cmath库中的sqrt函数。这是一个很好的例子,说明为什么要避免(或者至少要小心)使用from library import *导入整个库。

例如,如果您反转import语句,那么您将没有问题:

1
2
3
4
5
6
7
8
9
10
11
from cmath import *
from numpy import *

def f(x):
    x = float_(x)
    return 1./(1.-x)


def g(x) :
    x = float_(x)
    return sqrt(-1.)/(1.-x)

现在g(1.)返回nan,因为使用numpysqrt函数(可以处理负数)。但这仍然是不好的做法:如果你有一个大文件,那么不清楚使用哪个sqrt

我建议始终使用命名导入,例如import numpy as npimport cmath as cm。然后不导入函数sqrt,要定义g(x),需要编写np.sqrt(或cm.sqrt)。

但是,您注意到您不想更改您的功能。在这种情况下,您应该只从您需要的每个库中导入这些函数,即

1
2
3
4
5
6
7
8
9
10
11
from cmath import sin  # plus whatever functions you are using in that file
from numpy import sqrt, float_

def f(x):
    x = float_(x)
    return 1./(1.-x)


def g(x) :
    x = float_(x)
    return sqrt(-1.)/(1.-x)

不幸的是,你不能轻易摆脱转换为numpy float_,因为python浮点数与numpy浮点数不同,所以这种转换需要在某个时刻发生。