c++ Karatsuba 乘法使用向量

c++ Karatsuba Multiplication using Vectors

所以我一直在尝试为 Karatsuba 乘法算法编写一个算法,并且我一直在尝试使用向量作为我的数据结构来处理将要输入的非常长的数字...

我的程序可以很好地处理较小的数字,但它确实很难处理较大的数字,并且我得到了一个核心转储(Seg Fault)。当左侧数字小于右侧数字时,它也会输出奇怪的结果。

有什么想法吗?这是代码。

#include <iostream>
#include <string>
#include <vector>

#define max(a,b) ((a) > (b) ? (a) : (b))

using namespace std;

vector<int> add(vector<int> lhs, vector<int> rhs) {
    int length = max(lhs.size(), rhs.size());
    int carry = 0;
    int sum_col;
    vector<int> result;

    while(lhs.size() < length) {
        lhs.insert(lhs.begin(), 0);
    }
    while(rhs.size() < length) {
        rhs.insert(rhs.begin(), 0);
    }

    for(int i = length-1; i >= 0; i--) {
        sum_col = lhs[i] + rhs[i] + carry;
        carry = sum_col/10;
        result.insert(result.begin(), (sum_col%10));
    }
    if(carry) {
        result.insert(result.begin(), carry);
    }
    int x = 0;
    while(result[x] == 0) {
        x++;
    }
    result.erase(result.begin(), result.begin()+x);
    return result;
}

vector<int> subtract(vector<int> lhs, vector<int> rhs) {
    int length = max(lhs.size(), rhs.size());
    int diff;
    vector<int> result;

    while(lhs.size() < length) {
        lhs.insert(lhs.begin(), 0);
    }
    while(rhs.size() < length) {
        rhs.insert(rhs.begin(), 0);
    }

    for(int i = length-1; i >= 0; i--) {
        diff = lhs[i] - rhs[i];
        if(diff >= 0) {
            result.insert(result.begin(), diff);
        } else {
            int j = i - 1;
            while(j >= 0) {
                lhs[j] = (lhs[j] - 1) % 10;
                if(lhs[j] != 9) {
                    break;
                } else {
                    j--;
                }
            }
            result.insert(result.begin(), diff+10);
        }
    }
    int x = 0;
    while(result[x] == 0) {
        x++;
    }
    result.erase(result.begin(), result.begin()+x);
    return result;
}

vector<int> multiply(vector<int> lhs, vector<int> rhs) {
    int length = max(lhs.size(), rhs.size());
    vector<int> result;

    while(lhs.size() < length) {
        lhs.insert(lhs.begin(), 0);
    }
    while(rhs.size() < length) {
        rhs.insert(rhs.begin(), 0);
    }

    if(length == 1) {
        int res = lhs[0]*rhs[0];
        if(res >= 10) {
            result.push_back(res/10);
            result.push_back(res%10);
            return result;
        } else {
            result.push_back(res);
            return result;
        }
    }

    vector<int>::const_iterator first0 = lhs.begin();
    vector<int>::const_iterator last0 = lhs.begin() + (length/2);
    vector<int> lhs0(first0, last0);
    vector<int>::const_iterator first1 = lhs.begin() + (length/2);
    vector<int>::const_iterator last1 = lhs.begin() + ((length/2) + (length-length/2));
    vector<int> lhs1(first1, last1);
    vector<int>::const_iterator first2 = rhs.begin();
    vector<int>::const_iterator last2 = rhs.begin() + (length/2);
    vector<int> rhs0(first2, last2);
    vector<int>::const_iterator first3 = rhs.begin() + (length/2);
    vector<int>::const_iterator last3 = rhs.begin() + ((length/2) + (length-length/2));
    vector<int> rhs1(first3, last3);

    vector<int> p0 = multiply(lhs0, rhs0);
    vector<int> p1 = multiply(lhs1,rhs1);
    vector<int> p2 = multiply(add(lhs0,lhs1),add(rhs0,rhs1));
    vector<int> p3 = subtract(p2,add(p0,p1));

    for(int i = 0; i < 2*(length-length/2); i++) {
        p0.push_back(0);
    }
    for(int i = 0; i < (length-length/2); i++) {
        p3.push_back(0);
    }

    result = add(add(p0,p1), p3);

    int x = 0;
    while(result[x] == 0) {
        x++;
    }
    result.erase(result.begin(), result.begin()+x);
    return result;
}

int main() {
    vector<int> lhs;
    vector<int> rhs;
    vector<int> v;

    lhs.push_back(2);
    lhs.push_back(5);
    lhs.push_back(2);
    lhs.push_back(5);
    lhs.push_back(2);
    lhs.push_back(5);
    lhs.push_back(2);
    lhs.push_back(5);

    rhs.push_back(1);
    rhs.push_back(5);
    rhs.push_back(1);
    rhs.push_back(5);
    rhs.push_back(1);
    rhs.push_back(5);
    rhs.push_back(1);


    v = multiply(lhs, rhs);

    for(size_t i = 0; i < v.size(); i++) {
        cout << v[i];
    }
    cout << endl;

    return 0;
    }

subtract 存在几个问题。由于您无法表示负数,如果 rhs 大于 lhs,您的借用逻辑将在 lhs.[=28 的数据开始之前访问=]

如果结果为 0,您也可以在删除前导零时越过 result 的末尾。

你的借位计算是错误的,因为 -1 % 10 将 return -1,而不是 9,如果 lhs[j] 是 0。更好的计算方法是加 9(一小于您除以的值),lhs[j] = (lhs[j] + 9) % 10;.

在一个不相关的注释中,您可以简化范围迭代计算。由于 last0first1 具有相同的值,因此您可以对两者使用 last0,而 last1lhs.end()。这将 lhs1 简化为

vector<int> lhs1(last0, lhs.end());

你可以去掉 first1last1rhs 迭代器也是如此。