为 [-1, 1] 中的 c 计算 sqrt((b²*c²) / (1-c²)) 的数值稳定方法
Numerically stable way to compute sqrt((b²*c²) / (1-c²)) for c in [-1, 1]
对于 [-1, 1]
中的一些实际值 b
和 c
,我需要计算
sqrt( (b²*c²) / (1-c²) ) = (|b|*|c|) / sqrt((1-c)*(1+c))
当 c
接近 1 或 -1 时,分母中出现灾难性抵消。平方根可能也没有帮助。
我想知道我是否可以在这里应用一个聪明的技巧来避开 c=1 和 c=-1 周围的困难区域?
稳定性方面最有趣的部分是分母,sqrt(1 - c*c)
。为此,您需要做的就是将其扩展为 sqrt(1 - c) * sqrt(1 + c)
。我不认为这真的可以称为“聪明的把戏”,但这就是所需要的。
对于典型的二进制浮点格式(例如 IEEE 754 binary64,但其他常见格式应该表现得同样好,除了 double-double format), if c
is close to 1
then 1 - c
will be computed exactly, by Sterbenz' Lemma 这样令人不快的东西可能是例外,而 1 + c
没有任何稳定性问题。类似地,如果 c
接近 -1
,则 1 + c
将被准确计算,而 1 - c
将被准确计算。平方根和乘法运算不会引入重大的新错误。
这是一个数值演示,在具有 IEEE 754 二进制 64 浮点和正确舍入的 sqrt
运算的机器上使用 Python。
让我们取 c
接近(但小于)1
:
>>> c = float.fromhex('0x1.ffffffff24190p-1')
>>> c
0.9999999999
我们在这里必须小心一点:请注意,显示的十进制值 0.999999999
是 近似值 [=18= 的精确值].确切的值如十六进制字符串或分数形式的构造所示,562949953365017/562949953421312
,这就是我们关心获得良好结果的确切值。
表达式 sqrt(1 - c*c)
的精确值,四舍五入到小数点后 100 位,是:
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813
我使用 Python 的 decimal
module, and double-checked the result using Pari/GP 计算了这个。这是 Python 计算:
>>> from decimal import Decimal, getcontext
>>> getcontext().prec = 1000
>>> good = (1 - Decimal(c) * Decimal(c)).sqrt().quantize(Decimal("1e-100"))
>>> print(good)
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813
如果我们天真地计算,我们会得到这个结果:
>>> from math import sqrt
>>> naive = sqrt(1 - c*c)
>>> naive
1.4142136208793713e-05
我们可以很容易地计算出 ulps 错误的大致数量(对于正在进行的类型转换的数量表示歉意 - float
和 Decimal
实例不能直接在算术运算中混合):
>>> from math import ulp
>>> float((Decimal(naive) - good) / Decimal(ulp(float(good))))
208701.28298527992
所以天真的结果有几十万个 ulp - 粗略地说,我们已经失去了大约 5 个小数位的准确性。
现在让我们试试扩展版:
>>> better = sqrt(1 - c) * sqrt(1 + c)
>>> better
1.4142136208440158e-05
>>> float((Decimal(better) - good) / Decimal(ulp(float(good))))
-0.7170147200803595
所以这里我们的准确度优于 1 ulp 误差。不是完全正确的四舍五入,但次之。
通过更多的工作,假设 IEEE 754,应该可以在域 -1 < c < 1
上声明和证明表达式 sqrt(1 - c) * sqrt(1 + c)
中 ulp 错误数量的绝对上限二进制浮点、四舍五入模式和正确四舍五入的操作。我还没有这样做,但如果上限超过 10 ulp,我会感到非常惊讶。
Mark Dickinson 为一般情况提供了一个很好的 ,我将添加一些更专业的方法。
如今,许多计算环境都提供了一种称为融合乘加(简称 FMA)的运算,它是专门针对此类情况而设计的。在 fma(a, b, c)
的计算中,完整的乘积 a * b
(未截断和未舍入)进入与 c
的加法,然后在最后应用一次舍入。
目前出货的 GPU 和 CPU,包括基于 ARM64、x86-64 和 Power 架构的 GPU 和 CPU,通常包括 FMA 的快速硬件实现,它在 C 和 C++ 系列的编程语言以及许多其他作为标准数学函数 fma()
。一些——通常是较旧的——软件环境使用 FMA 的软件模拟,并且发现其中一些模拟有问题。此外,此类仿真往往很慢。
在 FMA 可用的情况下,表达式可以在数值上稳定并且没有过早上溢和下溢的风险,如 fabs (b * c) / sqrt (fma (c, -c, 1.0))
,其中 fabs()
是浮点操作数的绝对值运算和 sqrt()
计算平方根。某些环境还提供平方根倒数运算,通常称为 rsqrt()
,在这种情况下,可能的替代方法是使用 fabs (b * c) * rsqrt (fma (c, -c, 1.0))
。 rsqrt()
的使用避免了相对昂贵的除法,因此通常速度更快。但是,rsqrt()
的许多实现都没有像 sqrt()
那样正确四舍五入,因此准确性可能会稍差一些。
使用以下代码进行的快速实验似乎表明,只要 b
是 normal[=39=,基于 FMA 的表达式的最大误差约为 3 ulps ] 浮点数。我强调这 不 证明任何错误。自动 Herbie tool, which tries to find numerically advantageous rewrites of a given floating-point expression suggests 使用 fabs (b * c) * sqrt (1.0 / fma (c, -c, 1.0))
。然而,这似乎是一个虚假的结果,因为我既想不出任何特别的优势,也无法通过实验找到。
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>
#define USE_ORIGINAL (0)
#define USE_HERBIE (1)
/* function under test */
float func (float b, float c)
{
#if USE_HERBIE
return fabsf (b * c) * sqrtf (1.0f / fmaf (c, -c, 1.0f));
#else USE_HERBIE
return fabsf (b * c) / sqrtf (fmaf (c, -c, 1.0f));
#endif // USE_HERBIE
}
/* reference */
double funcd (double b, double c)
{
#if USE_ORIGINAL
double b2 = b * b;
double c2 = c * c;
return sqrt ((b2 * c2) / (1.0 - c2));
#else
return fabs (b * c) / sqrt (fma (c, -c, 1.0));
#endif
}
uint32_t float_as_uint32 (float a)
{
uint32_t r;
memcpy (&r, &a, sizeof r);
return r;
}
float uint32_as_float (uint32_t a)
{
float r;
memcpy (&r, &a, sizeof r);
return r;
}
uint64_t double_as_uint64 (double a)
{
uint64_t r;
memcpy (&r, &a, sizeof r);
return r;
}
double floatUlpErr (float res, double ref)
{
uint64_t i, j, err, refi;
int expoRef;
/* ulp error cannot be computed if either operand is NaN, infinity, zero */
if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
(res == 0.0f) || (ref == 0.0f)) {
return 0.0;
}
/* Convert the float result to an "extended float". This is like a float
with 56 instead of 24 effective mantissa bits.
*/
i = ((uint64_t)float_as_uint32(res)) << 32;
/* Convert the double reference to an "extended float". If the reference is
>= 2^129, we need to clamp to the maximum "extended float". If reference
is < 2^-126, we need to denormalize because of the float types's limited
exponent range.
*/
refi = double_as_uint64(ref);
expoRef = (int)(((refi >> 52) & 0x7ff) - 1023);
if (expoRef >= 129) {
j = 0x7fffffffffffffffULL;
} else if (expoRef < -126) {
j = ((refi << 11) | 0x8000000000000000ULL) >> 8;
j = j >> (-(expoRef + 126));
} else {
j = ((refi << 11) & 0x7fffffffffffffffULL) >> 8;
j = j | ((uint64_t)(expoRef + 127) << 55);
}
j = j | (refi & 0x8000000000000000ULL);
err = (i < j) ? (j - i) : (i - j);
return err / 4294967296.0;
}
// Fixes via: Greg Rose, KISS: A Bit Too Simple. http://eprint.iacr.org/2011/007
static unsigned int z=362436069,w=521288629,jsr=362436069,jcong=123456789;
#define znew (z=36969*(z&0xffff)+(z>>16))
#define wnew (w=18000*(w&0xffff)+(w>>16))
#define MWC ((znew<<16)+wnew)
#define SHR3 (jsr^=(jsr<<13),jsr^=(jsr>>17),jsr^=(jsr<<5)) /* 2^32-1 */
#define CONG (jcong=69069*jcong+13579) /* 2^32 */
#define KISS ((MWC^CONG)+SHR3)
#define N (20)
int main (void)
{
float b, c, errloc_b, errloc_c, res;
double ref, err, maxerr = 0;
c = -1.0f;
while (c <= 1.0f) {
/* try N random values of `b` per every value of `c` */
for (int i = 0; i < N; i++) {
/* allow only normals */
do {
b = uint32_as_float (KISS);
} while (!isnormal (b));
res = func (b, c);
ref = funcd ((double)b, (double)c);
err = floatUlpErr (res, ref);
if (err > maxerr) {
maxerr = err;
errloc_b = b;
errloc_c = c;
}
}
c = nextafterf (c, INFINITY);
}
#if USE_HERBIE
printf ("HERBIE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#else // USE_HERBIE
printf ("SIMPLE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#endif // USE_HERBIE
return EXIT_SUCCESS;
}
对于 [-1, 1]
中的一些实际值 b
和 c
,我需要计算
sqrt( (b²*c²) / (1-c²) ) = (|b|*|c|) / sqrt((1-c)*(1+c))
当 c
接近 1 或 -1 时,分母中出现灾难性抵消。平方根可能也没有帮助。
我想知道我是否可以在这里应用一个聪明的技巧来避开 c=1 和 c=-1 周围的困难区域?
稳定性方面最有趣的部分是分母,sqrt(1 - c*c)
。为此,您需要做的就是将其扩展为 sqrt(1 - c) * sqrt(1 + c)
。我不认为这真的可以称为“聪明的把戏”,但这就是所需要的。
对于典型的二进制浮点格式(例如 IEEE 754 binary64,但其他常见格式应该表现得同样好,除了 double-double format), if c
is close to 1
then 1 - c
will be computed exactly, by Sterbenz' Lemma 这样令人不快的东西可能是例外,而 1 + c
没有任何稳定性问题。类似地,如果 c
接近 -1
,则 1 + c
将被准确计算,而 1 - c
将被准确计算。平方根和乘法运算不会引入重大的新错误。
这是一个数值演示,在具有 IEEE 754 二进制 64 浮点和正确舍入的 sqrt
运算的机器上使用 Python。
让我们取 c
接近(但小于)1
:
>>> c = float.fromhex('0x1.ffffffff24190p-1')
>>> c
0.9999999999
我们在这里必须小心一点:请注意,显示的十进制值 0.999999999
是 近似值 [=18= 的精确值].确切的值如十六进制字符串或分数形式的构造所示,562949953365017/562949953421312
,这就是我们关心获得良好结果的确切值。
表达式 sqrt(1 - c*c)
的精确值,四舍五入到小数点后 100 位,是:
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813
我使用 Python 的 decimal
module, and double-checked the result using Pari/GP 计算了这个。这是 Python 计算:
>>> from decimal import Decimal, getcontext
>>> getcontext().prec = 1000
>>> good = (1 - Decimal(c) * Decimal(c)).sqrt().quantize(Decimal("1e-100"))
>>> print(good)
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813
如果我们天真地计算,我们会得到这个结果:
>>> from math import sqrt
>>> naive = sqrt(1 - c*c)
>>> naive
1.4142136208793713e-05
我们可以很容易地计算出 ulps 错误的大致数量(对于正在进行的类型转换的数量表示歉意 - float
和 Decimal
实例不能直接在算术运算中混合):
>>> from math import ulp
>>> float((Decimal(naive) - good) / Decimal(ulp(float(good))))
208701.28298527992
所以天真的结果有几十万个 ulp - 粗略地说,我们已经失去了大约 5 个小数位的准确性。
现在让我们试试扩展版:
>>> better = sqrt(1 - c) * sqrt(1 + c)
>>> better
1.4142136208440158e-05
>>> float((Decimal(better) - good) / Decimal(ulp(float(good))))
-0.7170147200803595
所以这里我们的准确度优于 1 ulp 误差。不是完全正确的四舍五入,但次之。
通过更多的工作,假设 IEEE 754,应该可以在域 -1 < c < 1
上声明和证明表达式 sqrt(1 - c) * sqrt(1 + c)
中 ulp 错误数量的绝对上限二进制浮点、四舍五入模式和正确四舍五入的操作。我还没有这样做,但如果上限超过 10 ulp,我会感到非常惊讶。
Mark Dickinson 为一般情况提供了一个很好的
如今,许多计算环境都提供了一种称为融合乘加(简称 FMA)的运算,它是专门针对此类情况而设计的。在 fma(a, b, c)
的计算中,完整的乘积 a * b
(未截断和未舍入)进入与 c
的加法,然后在最后应用一次舍入。
目前出货的 GPU 和 CPU,包括基于 ARM64、x86-64 和 Power 架构的 GPU 和 CPU,通常包括 FMA 的快速硬件实现,它在 C 和 C++ 系列的编程语言以及许多其他作为标准数学函数 fma()
。一些——通常是较旧的——软件环境使用 FMA 的软件模拟,并且发现其中一些模拟有问题。此外,此类仿真往往很慢。
在 FMA 可用的情况下,表达式可以在数值上稳定并且没有过早上溢和下溢的风险,如 fabs (b * c) / sqrt (fma (c, -c, 1.0))
,其中 fabs()
是浮点操作数的绝对值运算和 sqrt()
计算平方根。某些环境还提供平方根倒数运算,通常称为 rsqrt()
,在这种情况下,可能的替代方法是使用 fabs (b * c) * rsqrt (fma (c, -c, 1.0))
。 rsqrt()
的使用避免了相对昂贵的除法,因此通常速度更快。但是,rsqrt()
的许多实现都没有像 sqrt()
那样正确四舍五入,因此准确性可能会稍差一些。
使用以下代码进行的快速实验似乎表明,只要 b
是 normal[=39=,基于 FMA 的表达式的最大误差约为 3 ulps ] 浮点数。我强调这 不 证明任何错误。自动 Herbie tool, which tries to find numerically advantageous rewrites of a given floating-point expression suggests 使用 fabs (b * c) * sqrt (1.0 / fma (c, -c, 1.0))
。然而,这似乎是一个虚假的结果,因为我既想不出任何特别的优势,也无法通过实验找到。
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>
#define USE_ORIGINAL (0)
#define USE_HERBIE (1)
/* function under test */
float func (float b, float c)
{
#if USE_HERBIE
return fabsf (b * c) * sqrtf (1.0f / fmaf (c, -c, 1.0f));
#else USE_HERBIE
return fabsf (b * c) / sqrtf (fmaf (c, -c, 1.0f));
#endif // USE_HERBIE
}
/* reference */
double funcd (double b, double c)
{
#if USE_ORIGINAL
double b2 = b * b;
double c2 = c * c;
return sqrt ((b2 * c2) / (1.0 - c2));
#else
return fabs (b * c) / sqrt (fma (c, -c, 1.0));
#endif
}
uint32_t float_as_uint32 (float a)
{
uint32_t r;
memcpy (&r, &a, sizeof r);
return r;
}
float uint32_as_float (uint32_t a)
{
float r;
memcpy (&r, &a, sizeof r);
return r;
}
uint64_t double_as_uint64 (double a)
{
uint64_t r;
memcpy (&r, &a, sizeof r);
return r;
}
double floatUlpErr (float res, double ref)
{
uint64_t i, j, err, refi;
int expoRef;
/* ulp error cannot be computed if either operand is NaN, infinity, zero */
if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
(res == 0.0f) || (ref == 0.0f)) {
return 0.0;
}
/* Convert the float result to an "extended float". This is like a float
with 56 instead of 24 effective mantissa bits.
*/
i = ((uint64_t)float_as_uint32(res)) << 32;
/* Convert the double reference to an "extended float". If the reference is
>= 2^129, we need to clamp to the maximum "extended float". If reference
is < 2^-126, we need to denormalize because of the float types's limited
exponent range.
*/
refi = double_as_uint64(ref);
expoRef = (int)(((refi >> 52) & 0x7ff) - 1023);
if (expoRef >= 129) {
j = 0x7fffffffffffffffULL;
} else if (expoRef < -126) {
j = ((refi << 11) | 0x8000000000000000ULL) >> 8;
j = j >> (-(expoRef + 126));
} else {
j = ((refi << 11) & 0x7fffffffffffffffULL) >> 8;
j = j | ((uint64_t)(expoRef + 127) << 55);
}
j = j | (refi & 0x8000000000000000ULL);
err = (i < j) ? (j - i) : (i - j);
return err / 4294967296.0;
}
// Fixes via: Greg Rose, KISS: A Bit Too Simple. http://eprint.iacr.org/2011/007
static unsigned int z=362436069,w=521288629,jsr=362436069,jcong=123456789;
#define znew (z=36969*(z&0xffff)+(z>>16))
#define wnew (w=18000*(w&0xffff)+(w>>16))
#define MWC ((znew<<16)+wnew)
#define SHR3 (jsr^=(jsr<<13),jsr^=(jsr>>17),jsr^=(jsr<<5)) /* 2^32-1 */
#define CONG (jcong=69069*jcong+13579) /* 2^32 */
#define KISS ((MWC^CONG)+SHR3)
#define N (20)
int main (void)
{
float b, c, errloc_b, errloc_c, res;
double ref, err, maxerr = 0;
c = -1.0f;
while (c <= 1.0f) {
/* try N random values of `b` per every value of `c` */
for (int i = 0; i < N; i++) {
/* allow only normals */
do {
b = uint32_as_float (KISS);
} while (!isnormal (b));
res = func (b, c);
ref = funcd ((double)b, (double)c);
err = floatUlpErr (res, ref);
if (err > maxerr) {
maxerr = err;
errloc_b = b;
errloc_c = c;
}
}
c = nextafterf (c, INFINITY);
}
#if USE_HERBIE
printf ("HERBIE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#else // USE_HERBIE
printf ("SIMPLE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#endif // USE_HERBIE
return EXIT_SUCCESS;
}