有没有办法让这个功能更快? (C)

Is there a way to make this function faster? (C)

我在 C 中有一个代码,它以与人类相同的方式进行加法运算,因此,例如,如果我有两个数组 A[0..n-1]B[0..n-1],该方法将执行 C[0]=A[0]+B[0]C[1]=A[1]+B[1]...

我需要帮助使这个函数更快,即使解决方案使用的是内部函数。

我的主要问题是我有一个非常大的依赖性问题,因为迭代 i+1 取决于迭代 i 的进位,只要我使用基数 10。所以如果 A[0]=6B[0]=5, C[0] 必须是 1 我有一个进位 1 用于下一个加法。

我能做的更快的代码是这个:

void LongNumAddition1(unsigned char *Vin1, unsigned char *Vin2,
                      unsigned char *Vout, unsigned N) {
    for (int i = 0; i < N; i++) {
        Vout[i] = Vin1[i] + Vin2[i];
    } 

    unsigned char carry = 0;

    for (int i = 0; i < N; i++) {
        Vout[i] += carry;
        carry = Vout[i] / 10;
        Vout[i] = Vout[i] % 10;
    }
}

但我也尝试了这些方法,结果速度较慢:

void LongNumAddition1(unsigned char *Vin1, unsigned char *Vin2,
                      unsigned char *Vout, unsigned N) {
    unsigned char CARRY = 0;
    for (int i = 0; i < N; i++) {
        unsigned char R = Vin1[i] + Vin2[i] + CARRY;
        Vout[i] = R % 10; CARRY = R / 10;
    }
}

void LongNumAddition1(char *Vin1, char *Vin2, char *Vout, unsigned N) {
    char CARRY = 0;
    for (int i = 0; i < N; i++) {
        char R = Vin1[i] + Vin2[i] + CARRY;
        if (R <= 9) {
            Vout[i] = R;
            CARRY = 0;
        } else {
            Vout[i] = R - 10;
            CARRY = 1;
        }
    }
}

我一直在研究 google 并发现了一些与我实现的类似的伪代码,在 GeeksforGeeks 中也有针对此问题的另一种实现,但速度也较慢。

你能帮帮我吗?

提速候选人:

优化

确保您已启用编译器的速度优化设置。

restrict

编译器不知道更改 Vout[] 不会影响 Vin1[], Vin2[],因此在某些优化中受到限制。

使用restrict表示Vin1[], Vin2[]不受写入Vout[]的影响。

// void LongNumAddition1(unsigned char  *Vin1, unsigned char *Vin2, unsigned char *Vout, unsigned N)
void LongNumAddition1(unsigned char * restrict Vin1, unsigned char * restrict Vin2,
   unsigned char * restrict Vout, unsigned N)

注意:这会限制调用者使用与 Vin1, Vin2.

重叠的 Vout 调用函数

const

也使用 const 来帮助优化。 const 还允许 const 数组作为 Vin1, Vin2.

传递
// void LongNumAddition1(unsigned char * restrict Vin1, unsigned char * restrict Vin2,
   unsigned char * restrict Vout, unsigned N)
void LongNumAddition1(const unsigned char * restrict Vin1, 
   const unsigned char * restrict Vin2, 
   unsigned char * restrict Vout, 
   unsigned N)

unsigned

unsigned/int 是用于整数数学的 "goto" 类型。使用 <inttypes.h>.

中的 unsigned char CARRYchar CARRY,而不是 unsigneduint_fast8_t

%备选方案

sum = a+b+carry; if (sum >= 10) { sum -= 10; carry = 1; } else carry = 0; 之类的。


注意:我希望 LongNumAddition1() 到 return 最后进位。

为了提高 bignum 加法的速度,您应该将更多的十进制数字放入数组元素中。例如:您可以使用 uint32_t 而不是 unsigned char 并一次存储 9 个数字。

提高性能的另一个技巧是避免分支。

这是未经测试的代码修改版本:

void LongNumAddition1(const char *Vin1, const char *Vin2, char *Vout, unsigned N) {
    char carry = 0;
    for (int i = 0; i < N; i++) {
        char r = Vin1[i] + Vin2[i] + CARRY;
        carry = (r >= 10);
        Vout[i] = r - carry * 10;
    }
}

这是一次处理 9 位数字的修改版本:

#include <stdint.h>

void LongNumAddition1(const uint32_t *Vin1, const uint32_t *Vin2, uint32_t *Vout, unsigned N) {
    uint32_t carry = 0;
    for (int i = 0; i < N; i++) {
        uint32_t r = Vin1[i] + Vin2[i] + CARRY;
        carry = (r >= 1000000000);
        Vout[i] = r - carry * 1000000000;
    }
}

你可以在GodBolt's Compiler Explorer上查看gcc和clang生成的代码。

这是一个小测试程序:

#include <inttypes.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>

int LongNumConvert(const char *s, uint32_t *Vout, unsigned N) {
    unsigned i, len = strlen(s);
    uint32_t num = 0;
    if (len > N * 9)
        return -1;
    while (N * 9 > len + 8)
        Vout[--N] = 0;
    for (i = 0; i < len; i++) {
        num = num * 10 + (s[i] - '0');
        if ((len - i) % 9 == 1) {
            Vout[--N] = num;
            num = 0;
        }
    }
    return 0;
}

int LongNumPrint(FILE *fp, const uint32_t *Vout, unsigned N, const char *suff) {
    int len;
    while (N > 1 && Vout[N - 1] == 0)
        N--;
    len = fprintf(fp, "%"PRIu32"", Vout[--N]);
    while (N > 0)
        len += fprintf(fp, "%09"PRIu32"", Vout[--N]);
    if (suff)
        len += fprintf(fp, "%s", suff);
    return len;
}

void LongNumAddition(const uint32_t *Vin1, const uint32_t *Vin2,
                     uint32_t *Vout, unsigned N) {
    uint32_t carry = 0;
    for (unsigned i = 0; i < N; i++) {
        uint32_t r = Vin1[i] + Vin2[i] + carry;
        carry = (r >= 1000000000);
        Vout[i] = r - carry * 1000000000;
    }
}

int main(int argc, char *argv[]) {
    const char *sa = argc > 1 ? argv[1] : "123456890123456890123456890";
    const char *sb = argc > 2 ? argv[2] : "2035864230956204598237409822324";
#define NUMSIZE  111  // handle up to 999 digits
    uint32_t a[NUMSIZE], b[NUMSIZE], c[NUMSIZE];
    LongNumConvert(sa, a, NUMSIZE);
    LongNumConvert(sb, b, NUMSIZE);
    LongNumAddition(a, b, c, NUMSIZE);
    LongNumPrint(stdout, a, NUMSIZE, " + ");
    LongNumPrint(stdout, b, NUMSIZE, " = ");
    LongNumPrint(stdout, c, NUMSIZE, "\n");
    return 0;
}

在没有考虑特定系统的情况下讨论手动优化总是毫无意义的。如果我们假设您有某种带有数据缓存、指令缓存和分支预测的主流 32 位,那么:

  • 避免多重循环。您应该能够将它们合并为一个,从而获得重大的性能提升。这样你就不必多次接触同一个内存区域,你会减少分支的总数。每个 i < N 都必须由程序检查,因此减少检查量应该会提供更好的性能。此外,这可以提高数据缓存的可能性。

  • 对支持的最大对齐字大小执行所有操作。如果你有 32 bitter,你应该能够让这个算法一次处理 4 个字节,而不是一个字节一个字节地处理。这意味着以某种方式逐字节分配给 memcpy,一次执行 4 个字节。库质量代码就是这样做的。

  • 正确限定参数。您真的应该熟悉 常量正确性 这个术语。 Vin1Vin2 没有改变,所以它们应该是 const 并且不仅仅是为了性能,而是为了程序安全和 readability/maintainability.

  • 同样,如果你能保证参数没有指向重叠的内存区域,你就可以restrict限定所有指针。

  • 除法在许多 CPU 上是一项开销很大的操作,因此如果可以更改算法以摆脱 /%,那就去做吧。如果该算法是逐字节完成的,那么您可以牺牲 256 字节的内存来进行查找 table.

    (假设您可以在 ROM 中分配这样的查找 table 而无需引入等待状态依赖性等)。

  • 将进位更改为 32 位类型可能会在某些系统上提供更好的代码,而在其他系统上则更差。当我在 x86_64 上尝试这个时,它通过一条指令给出了稍微差一点的代码(非常小的差异)。

第一个循环

for (int i = 0; i < N; i++) {
    Vout[i] = Vin1[i] + Vin2[i];
} 

由编译器自动向量化。但是下一个循环

for (int i = 0; i < N; i++) {
    Vout[i] += carry;
    carry = Vout[i] / 10;
    Vout[i] = Vout[i] % 10;
}

包含一个 loop-carried dependence,它实质上序列化了整个循环(考虑将 1 添加到 99999999999999999 - 它只能按顺序计算,一次 1 位)。循环携带依赖是现代计算机科学中最令人头疼的问题之一。

这就是第一个版本更快的原因 - 它是部分矢量化的。任何其他版本都不是这种情况。

如何避免循环依赖?

计算机是 base-2 设备,在 base-10 算法方面出了名的糟糕。它不仅浪费 space,还会在每个数字之间造成人为的进位依赖关系。

如果您可以将数据从 base-10 转换为 base-2 表示,那么机器将两个数组相加会变得更加容易,因为机器可以轻松地在单次迭代中执行多个位的二进制加法。例如,对于 64 位机器,性能良好的表示可能是 uint64_t。请注意,带进位的流加法对于 SSE 仍然存在问题,但也存在一些选项。

不幸的是,C 编译器仍然很难生成带有进位传播的高效循环。出于这个原因,例如 libgmp 不是在 C 中而是在使用 ADC(带进位加法)指令的汇编语言中实现 bignum 加法。顺便说一句,libgmp 可以直接替代您项目中的许多 bignum 算术函数。

如果不想改变数据格式,可以试试SIMD。

typedef uint8_t u8x16 __attribute__((vector_size(16)));

void add_digits(uint8_t *const lhs, uint8_t *const rhs, uint8_t *out, size_t n) {
    uint8_t carry = 0;
    for (size_t i = 0; i + 15 < n; i += 16) {
        u8x16 digits = *(u8x16 *)&lhs[i] + *(u8x16 *)&rhs[i] + (u8x16){carry};

        // Get carries and almost-carries
        u8x16 carries = digits >= 10; // true is -1
        u8x16 full = digits == 9;

        // Shift carries
        carry = carries[15] & 1;
        __uint128_t carries_i = ((__uint128_t)carries) << 8;
        carry |= __builtin_add_overflow((__uint128_t)full, carries_i, &carries_i);

        // Add to carry chains and wrap
        digits += (((u8x16)carries_i) ^ full) & 1;
        // faster: digits = (u8x16)_mm_min_epu8((__m128i)digits, (__m128i)(digits - 10));
        digits -= (digits >= 10) & 10;

        *(u8x16 *)&out[i] = digits;
    }
}

这是每个数字 ~2 条指令。您需要添加代码来处理尾端。


这是算法的 运行-through。

首先,我们将我们的数字与上次迭代的进位相加:

lhs           7   3   5   9   9   2
rhs           2   4   4   9   9   7
carry                             1
         + -------------------------
digits        9   7   9  18  18  10

我们计算哪些数字会产生进位 (≥10),哪些会传播它们 (=9)。无论出于何种原因,对于 SIMD,true 为 -1。

carries       0   0   0  -1  -1  -1
full         -1   0  -1   0   0   0

我们将carries转为整数并移过来,同时将full转为整数。

              _   _   _   _   _   _
carries_i  000000001111111111110000
full       111100001111000000000000

现在我们可以将它们加在一起来传播进位。注意只有最低位是正确的。

              _   _   _   _   _   _
carries_i  111100011110111111110000
(relevant) ___1___1___0___1___1___0

有两个指标需要注意:

  1. carries_i 设置了最低位,digit ≠ 9。已经有进位进位了

  2. carries_i 已设置其最低位 un,并且 digit = 9这个方块上有一个进位,重置位。

我们用(((u8x16)carries_i) ^ full) & 1计算这个,然后加上digits

(c^f) & 1     0   1   1   1   1   0
digits        9   7   9  18  18  10
         + -------------------------
digits        9   8  10  19  19  10

然后我们去掉10,已经全部进位了

digits        9   8  10  19  19  10
(d≥10)&10     0   0  10  10  10  10
         - -------------------------
digits        9   8   0   9   9   0

我们还跟踪执行,这可能发生在两个地方。