计算 1/(1+exp(x)) python 时避免数值不稳定

Avoiding numerical instability when computing 1/(1+exp(x)) python

我想为(可能很大的)x 计算 1/(1+exp(x))。这是介于 0 和 1 之间的一个表现良好的函数。我可以做

import numpy as np
1.0/(1.0+np.exp(x))

但在这个天真的实现中,np.exp(x) 可能只是 return 0 或大 x 的无穷大,具体取决于符号。 python 中是否有可以帮助我解决问题的函数?

我正在考虑实现级数扩展和级数加速,但不知道这个问题是否已经解决了。

从根本上说,您受到浮点精度的限制。例如,如果您使用 64 位浮点数:

fmax_64 = np.finfo(np.float64).max  # the largest representable 64 bit float
print(np.log(fmax_64))
# 709.782712893

如果 x 大于 709,那么您将无法使用 64 位浮点数表示 np.exp(x)(或 1. / (1 + np.exp(x)))。

您可以使用扩展精度浮点数(即 np.longdouble):

fmax_long = np.finfo(np.longdouble).max
print(np.log(fmax_long))
# 11356.5234063

np.longdouble 的精度可能会因您的平台而异 - on x86 it is usually 80 bit,这将允许您使用最高约 11356 的 x 值:

func = lambda x: 1. / (1. + np.exp(np.longdouble(x)))
print(func(11356))
# 1.41861159972e-4932

除此之外,您需要重新考虑如何计算扩展,或者使用支持任意精度算术的 mpmath 之类的东西。然而,与 numpy 相比,这通常是以运行时性能更差为代价的,因为矢量化不再可能。

您可以使用 scipy.special.expit(-x)。它将避免 1.0/(1.0 + exp(x)).

产生的溢出警告