在大数乘法上使用 FFT 的问题

Problems using FFT on multiplication of large numbers

我最近试图将一堆(最多 10^5)非常小的数字(10^-4 的数量级)相互相乘,所以我将以 10^-(4* 的数量级结束10^5) 不适合任何变量。

我的方法如下:

  1. 将每个数字乘以 10^8 并将其存储在一个数组中,该数组被 10 的幂分割,即我从中得出一个由 10 的幂给出的多项式。一个例子是:p = 0.1234 -> p*10^8 = 12340000 -> A={0, 0 ,0, 0, 4, 3, 2, 1}.
  2. 使用 FFT 将这些数组相乘
  3. iFFT 结果

这是针对少数不同案例多次完成的。 最后我想知道的是一个这样的产品占所有产品总和的分数,精度可达 10^-6。为此,在第 2 步和第 3 步之间,将结果添加到一个总和数组中,该数组最后也是 iFFT。由于要求的精度很低,所以我不对多项式进行除法,而是只取前几个整数。

我的问题如下:FFT and/or iFFT 无法正常工作!我是新手,只实施过一次 FFT。我的代码如下(C++14):

#include <cstdio>
#include <vector>
#include <cmath>
#include <complex>
#include <iostream>
using namespace std;

const double PI = 4*atan(1.0);

vector<complex<double>> FFT (vector<complex<double>> A, long N)
{
    vector<complex<double>> Ans(0);
    if(N==1) 
    {
        Ans.push_back(A[0]);
        return Ans;
    }
    vector<complex<double>> even(N/2);
    vector<complex<double>> odd(N/2);
    for(long i=0; i<(N/2); i++)
    {
        even[i] = A[2*i];
        odd[i] = A[2*i+1];
    }
    vector<complex<double>> L1 = FFT(even, N/2);
    vector<complex<double>> L2 = FFT(odd, N/2);
    for(long i=0; i<N; i++)
    {
        complex<double> z(cos(2*PI*i/N),sin(2*PI*i/N));
        long k = i%(N/2);
        Ans.push_back(L1[k] + z*L2[k]);
    }
    return Ans;
    }

    vector<complex<double>> iFFT (vector<complex<double>> A, long N)
    {
    vector<complex<double>> Ans(0);
    if(N==1) 
    {
        Ans.push_back(A[0]);
        return Ans;
    }
    vector<complex<double>> even(N/2);
    vector<complex<double>> odd(N/2);
    for(long i=0; i<(N/2); i++)
    {
        even[i] = A[2*i];
        odd[i] = A[2*i+1];
    }
    vector<complex<double>> L1 = FFT(even, N/2);
    vector<complex<double>> L2 = FFT(odd, N/2);
    for(long i=0; i<N; i++)
    {
        complex<double> z(cos(-2*PI*i/N),sin(-2*PI*i/N));
        complex<double> inv(double(1.0/N), 0);
        long k = i%(N/2);
        Ans.push_back(inv*(L1[k]+z*L2[k]));
    }
    return Ans;
}

vector<complex<double>> PMult (vector<complex<double>> A, vector<complex<double>> B, long L) 
{
vector<complex<double>> Ans(L);
for(int i=0; i<L; i++)
{
    Ans[i] = A[i]*B[i];
}
return Ans;
}

vector<complex<double>> DtoA (double x)
{
    vector<complex<double>> ans(8);
    long X = long(x*10000000);
    ans[0] = complex<double>(double(X%10), 0.0); X/=10;
    ans[1] = complex<double>(double(X%10), 0.0); X/=10;
    ans[2] = complex<double>(double(X%10), 0.0); X/=10;
    ans[3] = complex<double>(double(X%10), 0.0); X/=10;
    ans[4] = complex<double>(double(X%10), 0.0); X/=10;
    ans[5] = complex<double>(double(X%10), 0.0); X/=10;
    ans[6] = complex<double>(double(X%10), 0.0); X/=10;
    ans[7] = complex<double>(double(X%10), 0.0);
    return ans;
}

int main()
{
    vector<vector<complex<double>>> W;
    int T, N, M;
    double p;
    scanf("%d", &T);
    while( T-- )
    {
        scanf("%d %d", &N, &M);
        W.resize(N); 
        for(int i=0; i<N; i++)
        {
            cin >> p;
            W[i] = FFT(DtoA(p),8);
            for(int j=1; j<M; j++)
            {
                cin >> p;
                W[i] = PMult(W[i], FFT(DtoA(p),8), 8);
            }
        }
        vector<complex<double>> Sum(8);
        for(int j=0; j<8; j++) Sum[j]=W[0][j];
        for(int i=1; i<N; i++)
        {
            for(int j=0; j<8; j++)
            {
                Sum[j]+=W[i][j];
            }
        }

        W[0]=iFFT(W[0],8);
        Sum=iFFT(Sum, 8);

        long X=0;
        long Y=0;
        int a;
        for(a=0; Sum[a].real()!=0; a++);
        for(int i=a; i<8; i++)
        {
            Y*=10;
            Y=Y+long(Sum[i].real());
        }
        for(int i=a; i<8; i++)
        {
            X*=10;
            X=X+long(W[0][i].real()); 
        }
        double ans = 0.0;
        if(Y) ans=double(X)/double(Y);
        printf("%.7f\n", ans);
    }
}

我观察到的是,对于除一个条目外仅由零组成的数组,FFT returns 具有多个非空条目的数组。此外,在 iFFT 完成后,结果仍然包含虚部非零的条目。

有人可以找到错误或提供提示以简化解决方案吗?因为我希望它快,所以我不想做一个天真的乘法。 Karatsuba 的算法会不会更好,因为我不需要复数?

我检查了一个旧的 Java 我的 FFT 实现(抱歉代码很糟糕)。我在您的 FFT 和 DtoA 函数中发现了以下内容:

  • 您的 DtoA 函数中缺少 0
  • 您在向量中设置系数的顺序被颠倒了(可能出于某种原因这是故意的)
  • Cooley Tukey 算法的 "combine" 阶段不正确:前半部分的形式为 a + b * c,后半部分的形式为 a - b * c

下面的代码效率不高,但应该很清楚。

#include <iostream>
#include <vector>
#include <complex>
#include <exception>
#include <cmath>

using namespace std;

vector<complex<double>> fft(const vector<complex<double>> &x) {
    int N = x.size();

    // base case
    if (N == 1) return vector<complex<double>> { x[0] };

    // radix 2 Cooley-Tukey FFT
    if (N % 2 != 0) { throw new std::exception(); }

    // fft of even terms
    vector<complex<double>> even, odd;
    for (int k = 0; k < N/2; k++) {
        even.push_back(x[2 * k]);
        odd.push_back(x[2 * k + 1]);
    }
    vector<complex<double>> q = fft(even);
    vector<complex<double>> r = fft(odd);

    // combine
    vector<complex<double>> y;
    for (int k = 0; k < N/2; k++) {
        double kth = -2 * k * M_PI / N;
        complex<double> wk = complex<double>(cos(kth), sin(kth));
        y.push_back(q[k] + (wk * r[k]));
    }
    for (int k = 0; k < N/2; k++) {
        double kth = -2 * k * M_PI / N;
        complex<double> wk = complex<double>(cos(kth), sin(kth));
        y.push_back(q[k] - (wk * r[k])); // you didn't do this
    }
    return y;
}

vector<complex<double>> DtoA (double x)
{
    vector<complex<double>> ans(8);
    long X = long(x*100000000); // a 0 was missing  here
    ans[7] = complex<double>(double(X%10), 0.0); X/=10;
    ans[6] = complex<double>(double(X%10), 0.0); X/=10;
    ans[5] = complex<double>(double(X%10), 0.0); X/=10;
    ans[4] = complex<double>(double(X%10), 0.0); X/=10;
    ans[3] = complex<double>(double(X%10), 0.0); X/=10;
    ans[2] = complex<double>(double(X%10), 0.0); X/=10;
    ans[1] = complex<double>(double(X%10), 0.0); X/=10;
    ans[0] = complex<double>(double(X%10), 0.0);
    return ans;
}

int main ()
{
    double n = 0.1234;
    auto nComplex = DtoA(n);
    for (const auto &e : nComplex) {
        std::cout << e << " ";
    }
    std::cout << std::endl;

    try {
        auto nFFT = fft(nComplex);

        for (const auto &e : nFFT) {
            std::cout << e << " ";
        }
    }
    catch (const std::exception &e) {
        std::cout << "exception" << std::endl;
    }
  return 0;
}

程序输出(我用Octave查了一下,是一样的):

(1,0) (2,0) (3,0) (4,0) (0,0) (0,0) (0,0) (0,0) 
(10,0) (-0.414214,-7.24264) (-2,2) (2.41421,-1.24264) (-2,0) (2.41421,1.24264) (-2,-2) (-0.414214,7.24264)

希望这对您有所帮助。

编辑:

关于反FFT,你可以证明

iFFT(x) = (1 / N) conjugate( FFT( conjugate(x) ) )

其中 N 是数组 x 中的元素数。所以你可以使用 fft 函数来计算 ifft:

vector<complex<double>> ifft(const vector<complex<double>> &vec) {
    std::vector<complex<double>> conj;
    for (const auto &e : vec) {
        conj.push_back(std::conj(e));
    }

    std::vector<complex<double>> vecFFT = fft(conj);

    std::vector<complex<double>> result;
    for (const auto &e : vecFFT) {
        result.push_back(std::conj(e) / static_cast<double>(vec.size()));
    }

    return result;
}

这里是修改后的main:

int main ()
{
    double n = 0.1234;
    auto nComplex = DtoA(n);
    for (const auto &e : nComplex) {
        std::cout << e << " ";
    }
    std::cout << std::endl;

    auto nFFT = fft(nComplex);
    for (const auto &e : nFFT)
        std::cout << e << " ";
    std::cout << std::endl;

    auto iFFT = ifft(nFFT);
    for (const auto &e : iFFT)
        std::cout << e << " ";
    return 0;
}

和输出:

(1,0) (2,0) (3,0) (4,0) (0,0) (0,0) (0,0) (0,0) 
(10,0) (-0.414214,-7.24264) (-2,2) (2.41421,-1.24264) (-2,0) (2.41421,1.24264) (-2,-2) (-0.414214,7.24264) 
(1,-0) (2,-1.08163e-16) (3,7.4688e-17) (4,2.19185e-16) (0,-0) (0,1.13882e-16) (0,-7.4688e-17) (0,-2.24904e-16)

请注意,像 1e-16 这样的数字几乎为 0(双精度数在硬件上并不完美)。