数值稳定的softmax

Numercially stable softmax

下面有计算softmax函数的数值稳定方法吗? 我得到的值在神经网络代码中变成了 Nans。

np.exp(x)/np.sum(np.exp(y))

按照您的情况计算 softmax 函数没有任何问题。问题似乎来自 exploding gradient 或此类训练方法问题。通过“裁剪值”或“选择正确的权重初始分布”来关注这些问题。

softmax exp(x)/sum(exp(x)) 实际上在数值上表现良好。它只有正项,所以我们不用担心失去意义,而且分母至少和分子一样大,所以结果保证在0和1之间。

唯一可能发生的事故是指数过流或欠流。 x 的单个元素上溢或所有元素下溢将导致输出或多或少无用。

但是通过使用恒等式 softmax(x) = softmax(x + c) 可以很容易地防止这种情况发生对于任何标量 c:从 x 中减去 max(x) 留下一个只有非正项的向量,排除溢出和至少一个排除零分母的零元素(某些但并非所有条目中的下溢是无害的)。

脚注:理论上,总和中的灾难性事故是可能的,但您需要荒谬的 项。例如,即使使用只能解析 3 位小数的 16 位浮点数——与 "normal" 64 位浮点数的 15 位小数相比——我们需要 2^1431 (~6 x 10^431)和 2^1432 得到的总和是 .

感谢Paul Panzer's的解释,但我想知道为什么我们需要减去max(x)。因此,我找到了更详细的资料,希望对和我有同样问题的人有所帮助。 请参阅以下 link 文章中的 "What’s up with that max subtraction?" 部分。

https://nolanbconaway.github.io/blog/2017/softmax-numpy

Softmax 函数容易出现两个问题:溢出下溢

溢出:当非常大的数字被近似infinity

时发生

下溢:当非常小的数字(在数字行中接近零)被近似(即四舍五入)为zero

为了在进行 softmax 计算时解决这些问题,一个常见的技巧是通过 从所有元素中减去其中的最大元素 来移动输入向量。对于输入向量 x,定义 z 使得:

z = x-max(x)

然后取新(稳定)向量的softmax z


示例:

def stable_softmax(x):
    z = x - max(x)
    numerator = np.exp(z)
    denominator = np.sum(numerator)
    softmax = numerator/denominator

    return softmax

# input vector
In [267]: vec = np.array([1, 2, 3, 4, 5])
In [268]: stable_softmax(vec)
Out[268]: array([ 0.01165623,  0.03168492,  0.08612854,  0.23412166,  0.63640865])

# input vector with really large number, prone to overflow issue
In [269]: vec = np.array([12345, 67890, 99999999])
In [270]: stable_softmax(vec)
Out[270]: array([ 0.,  0.,  1.])

在上述情况下,我们通过使用stable_softmax()

安全地避免了溢出问题

有关详细信息,请参阅本书 Numerical Computation in deep learning 章。

扩展@kmario23 的答案以支持一维或二维 numpy 数组或列表(如果您通过 softmax 函数传递一批结果很常见):

import numpy as np


def stable_softmax(x):
    z = x - np.max(x, axis=-1, keepdims=True)
    numerator = np.exp(z)
    denominator = np.sum(numerator, axis=-1, keepdims=True)
    softmax = numerator / denominator
    return softmax


test1 = np.array([12345, 67890, 99999999])  # 1D numpy
test2 = np.array([[12345, 67890, 99999999], # 2D numpy
                  [123, 678, 88888888]])    #
test3 = [12345, 67890, 999999999]           # 1D list
test4 = [[12345, 67890, 999999999]]         # 2D list

print(stable_softmax(test1))
print(stable_softmax(test2))
print(stable_softmax(test3))
print(stable_softmax(test4))

 [0. 0. 1.]

[[0. 0. 1.]
 [0. 0. 1.]]

 [0. 0. 1.]

[[0. 0. 1.]]