舍入误差:处理对具有非常小分量的向量的操作

Rounding errors: deal with operation on vectors with very small components

假设有一些向量(可以是 torch 张量或 numpy 数组),其中包含大量分量,每个分量都非常小 (~ 1e-10)。

假设我们要计算这些向量之一的范数(或其中两个向量之间的点积)。同样使用 float64 数据类型,每个分量的精度将为 ~1e-10,而 2 个分量的乘积(在 norm/dot 乘积计算期间)很容易达到 ~1e-20 导致很多舍入误差,加起来return一个错误的结果。

有没有办法处理这种情况? (例如,有没有办法为这些操作定义任意精度数组,或者一些内置的运算符自动处理这些操作?)

您在这里处理两个不同的问题:

下溢/上溢

计算平方时,计算非常小的值的范数可能会下溢为零。大值可能会溢出到无穷大。这可以通过使用稳定范数算法来解决。 处理此问题的一种简单方法是临时缩放值。参见例如:

a = np.array((1e-30, 2e-30), dtype='f4')
np.linalg.norm(a) # result is 0 due to underflow in single precision
scale = 1. / np.max(np.abs(a))
np.linalg.norm(a * scale) / scale # result is 2.236e-30

现在这是一个两次通过的算法,因为您必须在确定缩放值之前迭代所有数据。如果这不符合您的喜好,可以使用单遍算法,尽管您可能不想在 Python 中实现它们。经典的是 Blue 的算法: http://degiorgi.math.hr/~singer/aaa_sem/Float_Norm/p15-blue.pdf

一种更简单但效率更低的方法是简单地将调用链接到 hypot(它使用稳定的算法)。你不应该这样做,只是为了完成:

norm = 0.
for value in a:
    norm = math.hypot(norm, value)

甚至像这样的分层版本来减少 numpy 调用的次数:

norm = a
while len(norm) > 1:
    hlen = len(norm) >> 1
    front, back = norm[:hlen], norm[hlen: 2 * hlen]
    tail = norm[2 * hlen:] # only present with length is not even
    norm = np.append(np.hypot(front, back), tail)
norm = norm[0]

您可以自由组合这些策略。例如,如果您的数据不是一次性全部可用的,而是按块的(例如,因为数据集太大并且您从磁盘读取它),您可以为每个块选择一个缩放值,然后将块与一些链接在一起调用 hypot。

舍入误差

您会累积舍入误差,尤其是在累积不同大小的值时。如果你累积不同星座的价值,你也可能会经历灾难性的抵消。为避免这些问题,您需要使用补偿求和方案。 Python 与 math.fsum 一起提供了非常好的一个。 因此,如果您绝对需要最高精度,请使用以下内容:

math.sqrt(math.fsum(np.square(a * scale))) / scale

请注意,这对于一个简单的范数来说太过分了,因为累加中没有符号变化(因此没有抵消)并且平方会增加所有幅度差异,因此结果将始终由其最大的分量支配,除非你正在处理一个真正可怕的数据集。 numpy 没有为这些问题提供内置解决方案,这告诉您朴素算法实际上对于大多数现实世界的应用程序来说已经足够好了。在您真正 运行 陷入麻烦之前,没有理由过度实施。

应用于点积

我关注的是 l2 范数,因为这种情况通常被认为是危险的。当然,您可以将类似的策略应用于点积。

np.dot(a, b)

ascale = 1. / np.max(np.abs(a))
bscale = 1. / np.max(np.abs(b))

np.dot(a * ascale, b * bscale) / (ascale * bscale)

如果您使用混合精度,这将特别有用。例如,点积可以单精度计算,但 x / (ascale * bscale) 可以双精度甚至扩展精度计算。

当然 math.fsum 仍然可用:dot = math.fsum(a * b)

额外的想法

整个缩放本身会引入一些舍入误差,因为没有人向您保证 a/b 可以精确表示为浮点数。但是,您可以通过选择一个精确为 2 的幂的比例因子来避免这种情况。乘以 2 的幂在 FP 中始终是精确的(假设您保持在可表示的范围内)。您可以使用 math.frexp

获得指数