具有参数约束的定点幂 (pow) 函数的高效实现

Efficient implementation of fixed-point power (pow) function with argument constraints

我正在寻找函数 pow(a, b) 的高效实现,其中 a 限于区间 (0,1),而 b>= 1(两者都是实数,即不一定是整数)。

如果有帮助的话,b 不是一个很大的数字——假设它小于 10-20。这使得以迭代方式解决这个问题成为可能,迭代次数较少(大约等于 b)。

代码应该可以在 32 位微控制器上运行,可能没有浮点单元(即使用定点实现)。

我如何实现这样一个针对以下限制进行优化的功能?我正在寻找算法本身,所以伪代码是可以接受的。

OP 所述的问题并未充分明确:需要什么样的精度,需要什么样的性能?由于已知 a 为非负数,一个合适的起点是将 $a^b$ 计算为 exp2 (b * log2 (a)),并为这两个函数提供定点(fixed-point)实现。根据我在嵌入式系统中使用定点算术的经验,当用定点计算替换通用浮点计算时,s15.16 定点格式通常是可表示范围与精度之间的一个不错的折中,所以我会在这里采用它。

在定点算术中实现 exp2() 和 log2() 的主要选择有:按位(bit-wise)计算、查表并做二次插值,以及多项式逼近。后两者得益于能够产生双倍宽度乘积的硬件乘法器,而基于表的方法在有几千字节缓存可用时效果最好。OP 提到的微控制器规格表明代码和数据的资源可能都很有限,而且这是相当低端的硬件。因此按位方法可能是一个很好的起点,因为它只需要一张由两个函数共享的很小的表。代码体积也很小,特别是当指示编译器针对代码大小进行优化时(大多数编译器使用 -Os)。

这些函数的按位计算也称为伪除法(pseudo-division)和伪乘法(pseudo-multiplication),并且与 Henry Briggs (1561 – 1630) 计算对数表时所用的方法密切相关。

为了计算对数,我们通过反复试探,从一系列特定因子中选出这样一些因子:它们与函数参数相乘后,能将其归一化为 1。对于每个被选中的因子,我们把存储在表中的相应对数值累加起来。最终的总和就对应于函数参数的对数。对于整数部分,我们选择因子 $2^i$,i = 1, 2, 4, 8, ...;对于小数部分,我们选择因子 $1+2^{-i}$,i = 1, 2, 3, 4, 5, ...

指数算法与对数算法密切相关,并且恰好是它的逆过程。从所选因子的对数表序列中,我们挑出这样一些项:把它们从函数参数中依次减去后,最终能将参数减小到零。从 1 开始,我们把过程中被减去对数的所有因子相乘,得到的乘积就对应于求幂的结果。

对于大于 1 的值,可以把算法应用于函数参数的倒数(并相应地将结果取负)来计算其对数;同样,对负的函数参数求幂时,需要先将函数参数取负,最后再对结果取倒数。因此,在这方面这两种算法也是完全对称的。

以下 ISO-C99 代码是这两种算法的直接实现。请注意,此代码使用带舍入的定点乘法。其额外开销很小,却确实提高了整体 pow() 计算的精度。

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>

/*
 * s15.16 fixed-point quotient x / y.
 * The dividend is pre-scaled by 2**16 in 64-bit arithmetic so that the
 * quotient keeps 16 fraction bits; the result is truncated, not rounded.
 * The divisor is not checked: y must be non-zero.
 */
int32_t fxdiv_s15p16 (int32_t x, int32_t y)
{
    const int64_t scaled_dividend = (int64_t)x * 65536;
    const int64_t quotient = scaled_dividend / y;

    return (int32_t)quotient;
}

/*
 * s15.16 fixed-point product x * y with rounding.
 * The full 64-bit product carries 32 fraction bits; half of the final
 * least significant bit (2**15) is added before the low 16 bits are
 * discarded. The arithmetic is done on the unsigned bit pattern and the
 * low 32 bits of the shifted value are returned.
 */
int32_t fxmul_s15p16 (int32_t x, int32_t y)
{
    const uint64_t half_ulp = (1 << 15);
    uint64_t wide = (uint64_t)((int64_t)x * (int64_t)y);

    wide += half_ulp;
    return (int32_t)(uint32_t)(wide >> 16);
}

/*
 * Table of base-2 logarithms in s15.16 format, shared by the bit-wise
 * log2 and exp2 routines:
 *   tab[0..3]  : log2(2**8), log2(2**4), log2(2**2), log2(2**1)
 *   tab[4..19] : log2(1+2**(-1)), ..., log2(1+2**(-16))
 */
const uint32_t tab [20] = {
    /* integer steps */
    0x80000, 0x40000, 0x20000, 0x10000,
    /* fraction steps 1..8 */
    0x095c1, 0x0526a, 0x02b80, 0x01663, 0x00b5d, 0x005b9, 0x002e0, 0x00170,
    /* fraction steps 9..16 */
    0x000b8, 0x0005c, 0x0002e, 0x00017, 0x0000b, 0x00006, 0x00003, 0x00001
};
/* 1.0 in s15.16 */
const int32_t one_s15p16 = 65536;
/* -15.0 in s15.16 */
const int32_t neg_fifteen_s15p16 = -15 * 65536;

/*
 * Bit-wise log2() for an s15.16 argument, result in s15.16.
 *
 * An argument above 1.0 is first replaced by its reciprocal. The working
 * value is then multiplied by every factor from the shared table (powers
 * of two, then 1 + 2**(-shift)) that keeps it below 1.0, and the logarithm
 * of each factor applied is accumulated. The accumulated sum is returned
 * as is for arguments above 1.0 and negated otherwise.
 * No check is made for non-positive arguments.
 */
int32_t fxlog2_s15p16 (int32_t a)
{
    const int above_one = (a > one_s15p16);
    uint32_t mant;        /* working value, driven towards 1.0 from below */
    uint32_t acc = 0;     /* sum of log2 of all factors applied so far */
    int32_t trial;

    if (above_one) {
        mant = (uint32_t)fxdiv_s15p16 (one_s15p16, a);
    } else {
        mant = (uint32_t)a;
    }

    /* process integer bits: try the factors 2**8, 2**4, 2**2, 2**1 in turn */
    for (int k = 0; k < 4; k++) {
        trial = (int32_t)(mant << (8 >> k));
        if (trial < one_s15p16) {
            mant = (uint32_t)trial;
            acc += tab [k];
        }
    }

    /* process fraction bits: try the factors 1 + 2**(-shift) */
    for (int shift = 1; shift <= 16; shift++) {
        trial = (int32_t)(mant + (mant >> shift));
        if (trial < one_s15p16) {
            mant = (uint32_t)trial;
            acc += tab [3 + shift];
        }
    }

    return (int32_t)(above_one ? acc : (0 - acc));
}

/*
 * Bit-wise exp2() for an s15.16 argument, result in s15.16.
 *
 * Works on the magnitude of the argument: every table logarithm that can
 * be subtracted without going negative is subtracted, and the running
 * product (starting at 1.0) is multiplied by the matching factor. For a
 * negative argument the reciprocal of the product is returned.
 * Arguments <= -15.0 return 0 (underflow). There is no overflow check for
 * large positive arguments.
 */
int32_t fxexp2_s15p16 (int32_t a) 
{
    if (a <= neg_fifteen_s15p16) {
        return 0; // underflow
    }

    const int negative = (a < 0);
    uint32_t rem = (uint32_t)(negative ? (-a) : (a)); /* |a|, reduced towards zero */
    uint32_t prod = (uint32_t)one_s15p16;              /* running product of factors */
    int32_t trial;

    /* process integer bits: subtract log2(2**8), log2(2**4), log2(2**2), log2(2**1) */
    for (int k = 0; k < 4; k++) {
        trial = (int32_t)(rem - tab [k]);
        if (trial >= 0) {
            rem = (uint32_t)trial;
            prod = prod << (8 >> k);
        }
    }

    /* process fractional bits: subtract log2(1 + 2**(-shift)) */
    for (int shift = 1; shift <= 16; shift++) {
        trial = (int32_t)(rem - tab [3 + shift]);
        if (trial >= 0) {
            rem = (uint32_t)trial;
            prod = prod + (prod >> shift);
        }
    }

    if (negative) {
        return fxdiv_s15p16 (one_s15p16, (int32_t)prod);
    }
    return (int32_t)prod;
}

/*
 * compute a**b for a >= 0, as exp2 (b * log2 (a)); both arguments and the
 * result are in s15.16 format
 */
int32_t fxpow_s15p16 (int32_t a, int32_t b)
{
    const int32_t log2_a = fxlog2_s15p16 (a);
    const int32_t scaled_log = fxmul_s15p16 (b, log2_a);

    return fxexp2_s15p16 (scaled_log);
}

/* Convert an s15.16 fixed-point value to double by dividing out the 2**16 scale. */
double s15p16_to_double (int32_t a)
{
    const double scale = 65536.0;
    const double value = (double)a;

    return value / scale;
}

/*
 * Convert a double to s15.16 fixed point: scale by 2**16, then round to
 * nearest by adding 0.5 with the sign of the input before the truncating
 * conversion to integer. No range check is performed.
 */
int32_t double_to_s15p16 (double a)
{
    double scaled = a * 65536.0;

    if (a < 0) {
        scaled += -0.5;
    } else {
        scaled += 0.5;
    }
    return (int32_t)scaled;
}

/*
 * Test driver: for a = 0.125, 0.25, ..., 1.0 and b = -5.0, -4.5, ..., 6.0
 * print the fixed-point result of fxpow_s15p16 (a, b) next to a reference
 * computed with the double-precision pow(), both as raw s15.16 bit
 * patterns and as decimal values. One blank line separates the groups
 * for each value of a.
 *
 * NOTE(review): some combinations (e.g. a = 0.125, b = -5, i.e. 32768.0)
 * do not fit the s15.16 range, so those rows are not meaningful.
 */
int main (void) 
{
    const int32_t a_step  = double_to_s15p16 (0.125);
    const int32_t a_last  = double_to_s15p16 (1.0);
    const int32_t b_first = double_to_s15p16 (-5);
    const int32_t b_step  = double_to_s15p16 (0.5);
    const int32_t b_last  = double_to_s15p16 (6);

    for (int32_t a = a_step; a <= a_last; a += a_step) {
        for (int32_t b = b_first; b <= b_last; b += b_step) {
            double fa = s15p16_to_double (a);
            double fb = s15p16_to_double (b);
            double reff = pow (fa, fb);
            int32_t res = fxpow_s15p16 (a, b);
            int32_t ref = double_to_s15p16 (reff);
            double fres = s15p16_to_double (res);
            double fref = s15p16_to_double (ref);
            /* %x requires an unsigned argument; cast the signed bit patterns explicitly */
            printf ("a=%08x (%15.8e) b=%08x (% 15.8e) res=%08x (% 15.8e) ref=%08x (%15.8e)\n", 
                    (unsigned)a, fa, (unsigned)b, fb, (unsigned)res, fres, (unsigned)ref, fref);
        }
        printf ("\n");
    }
    return EXIT_SUCCESS;
}

对于具有适量缓存和快速整数乘法的处理器,我们可以在表中使用二次插值,对 log2() 和 exp2() 进行非常准确且快速的计算。为此,我们使用一种伪浮点表示,其中参数 $x = 2^i (1+f)$,f 在 [0,1) 内。小数部分 f 使用完整的 32 位字长进行归一化。对于对数,我们把落在 [0, 1) 内的 log2(1+f) 制成表。对于指数,我们把同样落在 [0,1) 内的 exp2(f)-1 制成表。显然,我们需要在表插值的结果上加 1。

对于二次插值,我们总共需要三个表项来拟合一条抛物线,其系数 a 和 b 可以即时计算。设函数参数与最近的节点之间的差为 dx,那么把 $a\,(dx)^2 + b\,(dx)$ 加到相应的表项上即可完成插值。拟合抛物线需要三个连续的表项,这也意味着我们需要把超出主插值区间的两个附加值也制成表。在中间计算中使用定点算术所引起的误差,可以通过在最后加上一个舍入常数来部分补偿,该常数需要通过实验确定。

    /*
     * for i = 0 to 65: log2 (1 + i/64) * 2**31
     *
     * Entry 64 is log2 (2) = 1.0 (0x80000000). Entries 64 and 65 lie
     * beyond the last interval start (i = 63); the table-based log2
     * routine reads logTab[idx+1] and logTab[idx+2] for a 6-bit idx.
     */
    uint32_t logTab [66] =
    {
        0x00000000, 0x02dcf2d1, 0x05aeb4dd, 0x08759c50, 
        0x0b31fb7d, 0x0de42120, 0x108c588d, 0x132ae9e2, 
        0x15c01a3a, 0x184c2bd0, 0x1acf5e2e, 0x1d49ee4c, 
        0x1fbc16b9, 0x22260fb6, 0x24880f56, 0x26e2499d, 
        0x2934f098, 0x2b803474, 0x2dc4439b, 0x30014ac6, 
        0x32377512, 0x3466ec15, 0x368fd7ee, 0x38b25f5a, 
        0x3acea7c0, 0x3ce4d544, 0x3ef50ad2, 0x40ff6a2e, 
        0x43041403, 0x450327eb, 0x46fcc47a, 0x48f10751, 
        0x4ae00d1d, 0x4cc9f1ab, 0x4eaecfeb, 0x508ec1fa, 
        0x5269e12f, 0x5440461c, 0x5612089a, 0x57df3fd0, 
        0x59a80239, 0x5b6c65aa, 0x5d2c7f59, 0x5ee863e5, 
        0x60a02757, 0x6253dd2c, 0x64039858, 0x65af6b4b, 
        0x675767f5, 0x68fb9fce, 0x6a9c23d6, 0x6c39049b, 
        0x6dd2523d, 0x6f681c73, 0x70fa728c, 0x72896373, 
        0x7414fdb5, 0x759d4f81, 0x772266ad, 0x78a450b8, 
        0x7a231ace, 0x7b9ed1c7, 0x7d17822f, 0x7e8d3846, 
        0x80000000, 0x816fe50b
    };
    
    /*
     * for i = 0 to 129: (exp2 (i/128) - 1) * 2**31
     *
     * Entry 0 is 0 and entry 128 is exp2 (1) - 1 = 1.0 (0x80000000).
     * Entries 128 and 129 lie beyond the last interval start (i = 127);
     * the table-based exp2 routine reads expTab[idx+1] and expTab[idx+2]
     * for a 7-bit idx.
     */
    uint32_t expTab [130] =
    {
        0x00000000, 0x00b1ed50, 0x0164d1f4, 0x0218af43,
        0x02cd8699, 0x0383594f, 0x043a28c4, 0x04f1f656,
        0x05aac368, 0x0664915c, 0x071f6197, 0x07db3580,
        0x08980e81, 0x0955ee03, 0x0a14d575, 0x0ad4c645,
        0x0b95c1e4, 0x0c57c9c4, 0x0d1adf5b, 0x0ddf0420,
        0x0ea4398b, 0x0f6a8118, 0x1031dc43, 0x10fa4c8c,
        0x11c3d374, 0x128e727e, 0x135a2b2f, 0x1426ff10,
        0x14f4efa9, 0x15c3fe87, 0x16942d37, 0x17657d4a,
        0x1837f052, 0x190b87e2, 0x19e04593, 0x1ab62afd,
        0x1b8d39ba, 0x1c657368, 0x1d3ed9a7, 0x1e196e19,
        0x1ef53261, 0x1fd22825, 0x20b05110, 0x218faecb,
        0x22704303, 0x23520f69, 0x243515ae, 0x25195787,
        0x25fed6aa, 0x26e594d0, 0x27cd93b5, 0x28b6d516,
        0x29a15ab5, 0x2a8d2653, 0x2b7a39b6, 0x2c6896a5,
        0x2d583eea, 0x2e493453, 0x2f3b78ad, 0x302f0dcc,
        0x3123f582, 0x321a31a6, 0x3311c413, 0x340aaea2,
        0x3504f334, 0x360093a8, 0x36fd91e3, 0x37fbefcb,
        0x38fbaf47, 0x39fcd245, 0x3aff5ab2, 0x3c034a7f,
        0x3d08a39f, 0x3e0f680a, 0x3f1799b6, 0x40213aa2,
        0x412c4cca, 0x4238d231, 0x4346ccda, 0x44563ecc,
        0x45672a11, 0x467990b6, 0x478d74c9, 0x48a2d85d,
        0x49b9bd86, 0x4ad2265e, 0x4bec14ff, 0x4d078b86,
        0x4e248c15, 0x4f4318cf, 0x506333db, 0x5184df62,
        0x52a81d92, 0x53ccf09a, 0x54f35aac, 0x561b5dff,
        0x5744fccb, 0x5870394c, 0x599d15c2, 0x5acb946f,
        0x5bfbb798, 0x5d2d8185, 0x5e60f482, 0x5f9612df,
        0x60ccdeec, 0x62055b00, 0x633f8973, 0x647b6ca0,
        0x65b906e7, 0x66f85aab, 0x68396a50, 0x697c3840,
        0x6ac0c6e8, 0x6c0718b6, 0x6d4f301f, 0x6e990f98,
        0x6fe4b99c, 0x713230a8, 0x7281773c, 0x73d28fde,
        0x75257d15, 0x767a416c, 0x77d0df73, 0x792959bb,
        0x7a83b2db, 0x7bdfed6d, 0x7d3e0c0d, 0x7e9e115c,
        0x80000000, 0x8163daa0,
    };
    
    /* clz_tab[k]: leading-zero count belonging to the 5-bit key k computed in clz() */
    uint8_t clz_tab[32] = {
        31, 22, 30, 21, 18, 10, 29,  2,
        20, 17, 15, 13,  9,  6, 28,  1,
        23, 19, 11,  3, 16, 14,  7, 24,
        12,  4,  8, 25,  5, 26, 27,  0
    };

    /*
     * count leading zeros; this is a machine instruction on many architectures
     *
     * Every bit below the most significant set bit is set first, so the
     * value depends only on the position of that bit. A multiplication
     * then maps it to a distinct 5-bit key (the top 5 bits of the low 32
     * bits of the product), which indexes clz_tab. An input of 0 maps to
     * key 0 and therefore returns 31, the same as an input of 1.
     */
    int32_t clz (uint32_t a)
    {
        for (int s = 1; s <= 16; s <<= 1) {
            a |= a >> s;
        }
        const uint32_t key = (0x07c4acdd * a) >> 27;
        return clz_tab [key];
    }
    
    /*
     * log2() of an s15.16 argument, result in s15.16, using a 64-interval
     * table of log2 (1 + f) with quadratic interpolation.
     *
     * The argument is treated as 2**(15 - lz) * (1 + f), where lz is the
     * number of leading zero bits and f the remaining bits taken as a
     * 32-bit fraction. The argument is expected to be positive; no check
     * is made.
     */
    int32_t fxlog2_s15p16 (int32_t arg)
    {
        int32_t lz, f1, f2, dx, a, b, approx;
        uint32_t t, idx, x = arg;
        lz = clz (x);
        /* shift off integer bits and normalize fraction 0 <= f < 1; the
           shift is split in two because lz + 1 can reach 32, and shifting
           a 32-bit value by 32 is undefined */
        t = (x << lz) << 1;
        /* index table of values log2 (1 + f) using 6 msbs of fraction */
        idx = (unsigned)t >> (32 - 6);
        /* difference between argument and smallest sampling point */
        dx = t - (idx << (32 - 6));
        /* fit parabola through closest three sampling points; find coeffs a, b */
        f1 = (logTab[idx+1] - logTab[idx]);
        f2 = (logTab[idx+2] - logTab[idx]);
        a = f2 - (f1 << 1);
        b = (f1 << 1) - a;
        /* find function value for argument by computing ((a*dx+b)*dx) */
        approx = (int32_t)((((int64_t)a)*dx) >> (32 - 6)) + b;
        approx = (int32_t)((((int64_t)approx)*dx) >> (32 - 6 + 1));
        /* compute fraction result; add experimentally determined rounding constant */
        approx = logTab[idx] + approx + 0x410d;
        /* combine integer and fractional parts of result; the integer part
           15 - lz is negative for arguments below 1.0, so it is shifted as
           an unsigned value (left-shifting a negative int is undefined) */
        approx = (int32_t)(((uint32_t)(15 - lz) << 16) + (((uint32_t)approx) >> 15));
        return approx;
    }
    
    /*
     * exp2() of an s15.16 argument, result in s15.16, using a 128-interval
     * table of exp2 (f) - 1 with quadratic interpolation.
     *
     * The argument is split into an integer part i and a 16-bit fraction f;
     * exp2 (f) is interpolated from the table and then scaled by 2**i with
     * a shift. Returns 0 when i < -16 (underflow) and INT32_MAX when
     * i > 14 (the result would not fit the s15.16 format).
     */
    int32_t fxexp2_s15p16 (int32_t x)
    {
        int32_t f1, f2, dx, a, b, approx, idx, i, f;
    
        /* extract integer portion; 2**i is realized as a shift at the end
           (relies on >> of a negative value being an arithmetic shift) */
        i = (x >> 16);
        /* Handle underflow to 0 before any shifting: for very negative i
           the final shift count (14 - i) would reach or exceed 32, which
           is undefined */
        if (i < -16) {
            return 0;
        }
        /* Saturate on overflow: for i > 14 the final shift count (14 - i)
           would be negative, which is undefined */
        if (i > 14) {
            return INT32_MAX;
        }
        /* extract fraction f so we can compute 2**(f)-1, 0 <= f < 1 */
        f = x & 0xffff;
        /* index table of values exp2 (f) - 1 using 7 msbs of fraction */
        idx = (uint32_t)f >> (16 - 7);
        /* difference between argument and next smaller sampling point */
        dx = f - (idx << (16 - 7));
        /* fit parabola through closest three sampling point; find coeffs a,b */
        f1 = (expTab[idx+1] - expTab[idx]);
        f2 = (expTab[idx+2] - expTab[idx]);
        a = f2 - (f1 << 1);
        b = (f1 << 1) - a;
        /* find function value for argument by computing ((a*dx+b)*dx) */
        approx = (int32_t)((((int64_t)a)*dx) >> (16 - 7)) + b;
        approx = (int32_t)((((int64_t)approx)*dx) >> (16 - 7 + 1));
        /* combine integer and fractional parts of result; add 1.0 and experimentally determined rounding constant */
        approx = (((expTab[idx] + (unsigned)approx + 0x80000012U) >> (14 - i)) + 1) >> 1;
        return approx;
    }

最后,在数据存储有限但提供非常快的整数乘法器的平台上,可以使用极小极大(minimax)型的多项式近似,即使最大误差最小化的多项式近似。这种代码与基于表的方法类似,只是两个核心近似——f 在 [0,1) 内的 log2(1+f) 和 exp2(f)-1——是通过多项式而不是表来计算的。为了获得最高的精度和代码效率,我们希望把系数表示为无符号纯小数,即使用 u0.32 定点格式。对于 log2(),首项系数约为 1.44,因此放不进这种格式;不过,把这个系数拆成 1 和 0.44 两部分就可以轻松解决。对于定点计算,通常不必采用 Horner 方案求值;采用类 Estrin 方案可以更好地利用流水线乘法器,并提高指令级并行度,如下文所示:

Claude-Pierre Jeannerod, Jingyan Jourdan-Lu, "Simultaneous floating-point sine and cosine for VLIW integer processors"(面向 VLIW 整数处理器的浮点正弦与余弦同时计算)。载于第 23 届 IEEE 专用系统、体系结构与处理器国际会议(ASAP),2012 年 7 月,第 69-76 页(可在线获取预印本)

    /*
     * upper 32 bits of the full 64-bit product of two unsigned 32-bit
     * values; a single instruction in many 32-bit architectures
     */
    uint32_t umul32hi (uint32_t a, uint32_t b)
    {
        const uint64_t product = (uint64_t)a * (uint64_t)b;

        return (uint32_t)(product >> 32);
    }
    
    /* lookup table for clz(): leading-zero count for each 5-bit key */
    uint8_t clz_tab[32] = {
        31, 22, 30, 21,
        18, 10, 29,  2,
        20, 17, 15, 13,
         9,  6, 28,  1,
        23, 19, 11,  3,
        16, 14,  7, 24,
        12,  4,  8, 25,
         5, 26, 27,  0
    };

    /*
     * count leading zeros; this is a machine instruction on many architectures
     *
     * The most significant set bit is propagated into all lower bit
     * positions; the result is multiplied by a constant and the top 5
     * bits of the 32-bit product select the entry of clz_tab. Both 0 and
     * 1 yield 31.
     */
    int32_t clz (uint32_t a)
    {
        uint32_t smeared = a;
        uint32_t key;

        smeared |= smeared >> 1;
        smeared |= smeared >> 2;
        smeared |= smeared >> 4;
        smeared |= smeared >> 8;
        smeared |= smeared >> 16;
        key = (0x07c4acdd * smeared) >> 27;
        return clz_tab [key];
    }
    
    /*
     * compute log2() with s15.16 fixed-point argument and result
     *
     * The argument is treated as 2**(15 - lz) * (1 + f), where lz is the
     * number of leading zero bits and f the remaining bits taken as a
     * u0.32 fraction. log2 (1 + f) is evaluated as
     *   f + f * (a0 - a1*f + a2*f**2 - a3*f**3 + ... - a7*f**7)
     * with u0.32 coefficients, grouped into three partial sums (h, m, l)
     * that can be computed independently. The argument is expected to be
     * positive; no check is made.
     */
    int32_t fxlog2_s15p16 (int32_t arg)
    {
        const uint32_t a0 = (uint32_t)((1.44269476063 - 1)* (1LL << 32) + 0.5);
        const uint32_t a1 = (uint32_t)(7.2131008654833e-1 * (1LL << 32) + 0.5);
        const uint32_t a2 = (uint32_t)(4.8006370104849e-1 * (1LL << 32) + 0.5);
        const uint32_t a3 = (uint32_t)(3.5339481476694e-1 * (1LL << 32) + 0.5);
        const uint32_t a4 = (uint32_t)(2.5600972794928e-1 * (1LL << 32) + 0.5);
        const uint32_t a5 = (uint32_t)(1.5535182948224e-1 * (1LL << 32) + 0.5);
        const uint32_t a6 = (uint32_t)(6.3607925549150e-2 * (1LL << 32) + 0.5);
        const uint32_t a7 = (uint32_t)(1.2319647939876e-2 * (1LL << 32) + 0.5);
        int32_t lz;
        uint32_t approx, h, m, l, z, y, x = arg;
        lz = clz (x);
        /* shift off integer bits and normalize fraction 0 <= f < 1; the
           shift is split in two because lz + 1 can reach 32, and shifting
           a 32-bit value by 32 is undefined */
        x = (x << lz) << 1;
        y = umul32hi (x, x); // f**2
        z = umul32hi (y, y); // f**4
        /* evaluate minimax polynomial to compute log2(1+f) */
        h = a0 - umul32hi (a1, x);
        m = umul32hi (a2 - umul32hi (a3, x), y);
        l = umul32hi (a4 - umul32hi (a5, x) + umul32hi(a6 - umul32hi(a7, x), y), z);
        approx = x + umul32hi (x, h + m + l);
        /* combine integer and fractional parts of result; round result.
           The integer part 15 - lz is negative for arguments below 1.0, so
           it is shifted as an unsigned value (left-shifting a negative int
           is undefined) */
        approx = ((uint32_t)(15 - lz) << 16) + (((approx >> 15) + 1) >> 1);
        return approx;
    }

    /*
     * compute exp2() with s15.16 fixed-point argument and result
     *
     * The argument is split into a signed integer part i and a fraction f
     * normalized to u0.32. exp2 (f) - 1 is evaluated as
     *   f * (a0 + a1*f + a2*f**2 + ... + a6*f**6)
     * with u0.32 coefficients, grouped into three partial sums (h, m, l);
     * 1.0 is then added and the result scaled by 2**i with shifts.
     * Returns 0 when i < -16 (underflow) and INT32_MAX when i > 14 (the
     * result would not fit the s15.16 format).
     */
    int32_t fxexp2_s15p16 (int32_t arg)
    {
        const uint32_t a0 = (uint32_t)(6.9314718107e-1 * (1LL << 32) + 0.5);
        const uint32_t a1 = (uint32_t)(2.4022648809e-1 * (1LL << 32) + 0.5);
        const uint32_t a2 = (uint32_t)(5.5504413787e-2 * (1LL << 32) + 0.5);
        const uint32_t a3 = (uint32_t)(9.6162736882e-3 * (1LL << 32) + 0.5);
        const uint32_t a4 = (uint32_t)(1.3386828359e-3 * (1LL << 32) + 0.5);
        const uint32_t a5 = (uint32_t)(1.4629773796e-4 * (1LL << 32) + 0.5);
        const uint32_t a6 = (uint32_t)(2.0663021132e-5 * (1LL << 32) + 0.5);
        int32_t i;
        uint32_t approx, h, m, l, s, q, f, x = arg;
    
        /* extract integer portion (sign-extend the upper 16 bits); 2**i is
           realized as a shift at the end */
        i = ((x >> 16) ^ 0x8000) - 0x8000;
        /* handle underflow to 0 before any shifting: for i <= -17 the
           shift count (15 - i) would reach or exceed 32, which is undefined */
        if (i < -16) {
            return 0;
        }
        /* saturate on overflow: for i > 14 the shift count (14 - i) would
           be negative, which is undefined */
        if (i > 14) {
            return INT32_MAX;
        }
        /* extract and normalize fraction f to compute 2**(f)-1, 0 <= f < 1 */
        f = x << 16;
        /* evaluate minimax polynomial to compute exp2(f)-1 */
        s = umul32hi (f, f); // f**2
        q = umul32hi (s, s); // f**4
        h = a0 + umul32hi (a1, f);
        m = umul32hi (a2 + umul32hi (a3, f), s);
        l = umul32hi (a4 + umul32hi (a5, f) + umul32hi (a6, s), q);
        approx = umul32hi (f, h + m + l);
        /* combine integer and fractional parts of result; round result */
        approx = ((approx >> (15 - i)) + (0x80000000 >> (14 - i)) + 1) >> 1;
        return approx;
    }