使用 python3 实现 FFT 和 IFFT

realize FFT and IFFT using python3

当我用FFT将两个大整数相乘时,发现FFT和IFFT的结果总是不对。

方法

为了实现FFT,我按照伪代码如下: the pseudocode of FFT

FFT和IFFT的方程如下。所以,在实现IFFT的时候,我就是把a换成y,把omega换成omega ^^ -1,再除以n。并且,在我的函数中使用flag来区分它们。

问题

为了找到问题,我尝试比较 numpy.fft 和我的函数之间的结果。

  1. FFT。 numpy 的结果和我的函数看起来是一样的,但是 images 的符号是相反的。例如(下面 case2 的第二个元素):
    • 我的函数结果:-4-9.65685424949238j
    • numpy 结果:-4+9.65685424949238j
  2. IFFT。我只是发现它不对,找不到任何规则。

python代码

这是我的函数FFT,和比较:

from typing import List
from cmath import pi, exp
from numpy.fft import fft, ifft


def FFT(a: List, flag: bool) -> List:
    """realize DFT using FFT"""
    n = len(a)
    if n == 1:
        return a

    # complex root
    omg_n = exp(2 * pi * 1j / n)
    if flag:
        # IFFT
        omg_n = 1 / omg_n
    omg = 1

    # split a into 2 part
    a0 = a[::2]  # even
    a1 = a[1::2]  # odd

    # corresponding y
    y0 = FFT(a0, flag)
    y1 = FFT(a1, flag)

    # result y
    y = [0] * n
    for k in range(n // 2):
        y[k] = y0[k] + omg * y1[k]
        y[k + n // 2] = y0[k] - omg * y1[k]
        omg = omg * omg_n

    # IFFT
    if flag:
        y = [i / n for i in y]
    return y

if __name__ == '__main__':
    test_cases = [
        [1, 1],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 4, 2, 9, 0, 0, 3, 8, 9, 1, 4, 0, 0, 0, 0, 0, ],
    ]

    print("test FFT")
    for i, case in enumerate(test_cases):
        print(f"case{i + 1}", case)
        manual_result = FFT(case, False)
        numpy_result = fft(case).tolist()
        print("manual_result:", manual_result)
        print("numpy_result:", numpy_result)
        print("difference:", [i - j for i, j in zip(manual_result, numpy_result)])
        print()



    print("test IFFT")
    for i, case in enumerate(test_cases):
        print(f"case{i + 1}", case)
        manual_result = FFT(case, True)
        numpy_result = ifft(case).tolist()
        print("manual_result:", manual_result)
        print("numpy_result:", numpy_result)
        print("difference:", [i - j for i, j in zip(manual_result, numpy_result)])
        print()

FFT输出:

test FFT
case1 [1, 1]
manual_result: [2, 0]
numpy_result: [(2+0j), 0j]
difference: [0j, 0j]

case2 [1, 2, 3, 4, 5, 6, 7, 8]
manual_result: [36, (-4-9.65685424949238j), (-4-4.000000000000001j), (-4-1.6568542494923815j), -4, (-4+1.6568542494923806j), (-4+4.000000000000001j), (-3.999999999999999+9.656854249492381j)]
numpy_result: [(36+0j), (-4+9.65685424949238j), (-4+4j), (-4+1.6568542494923806j), (-4+0j), (-4-1.6568542494923806j), (-4-4j), (-4-9.65685424949238j)]
difference: [0j, -19.31370849898476j, -8j, -3.313708498984762j, 0j, 3.313708498984761j, 8j, (8.881784197001252e-16+19.31370849898476j)]

case3 [1, 4, 2, 9, 0, 0, 3, 8, 9, 1, 4, 0, 0, 0, 0, 0]
manual_result: [41, (-12.710780677203363+13.231540329804117j), (12.82842712474619+7.2426406871192865j), (-14.692799048494296+7.4256307475248935j), (1.0000000000000013-12j), (5.763866860359768+6.0114171851517995j), (7.171572875253808+1.2426406871192839j), (-10.360287134662114+11.817326767431025j), -3, (-10.360287134662112-11.817326767431021j), (7.17157287525381-1.2426406871192848j), (5.763866860359771-6.011417185151798j), (0.9999999999999987+12j), (-14.692799048494292-7.425630747524895j), (12.828427124746192-7.242640687119286j), (-12.710780677203362-13.23154032980412j)]
numpy_result: [(41+0j), (-12.710780677203363-13.231540329804115j), (12.82842712474619-7.242640687119286j), (-14.692799048494292-7.4256307475248935j), (1+12j), (5.763866860359768-6.011417185151798j), (7.17157287525381-1.2426406871192857j), (-10.360287134662112-11.81732676743102j), (-3+0j), (-10.360287134662112+11.81732676743102j), (7.17157287525381+1.2426406871192857j), (5.763866860359768+6.011417185151798j), (1-12j), (-14.692799048494292+7.4256307475248935j), (12.82842712474619+7.242640687119286j), (-12.710780677203363+13.231540329804115j)]
difference: [0j, 26.46308065960823j, 14.485281374238571j, (-3.552713678800501e-15+14.851261495049787j), (1.3322676295501878e-15-24j), 12.022834370303597j, (-1.7763568394002505e-15+2.4852813742385695j), (-1.7763568394002505e-15+23.634653534862046j), 0j, -23.63465353486204j, -2.4852813742385704j, (3.552713678800501e-15-12.022834370303595j), (-1.3322676295501878e-15+24j), -14.851261495049789j, (1.7763568394002505e-15-14.485281374238571j), (1.7763568394002505e-15-26.463080659608238j)]

IFFT 结果:

test IFFT
case1 [1, 1]
manual_result: [1.0, 0.0]
numpy_result: [(1+0j), 0j]
difference: [0j, 0j]

case2 [1, 2, 3, 4, 5, 6, 7, 8]
manual_result: [0.5625, (-0.0625+0.15088834764831843j), (-0.0625+0.062499999999999986j), (-0.0625+0.025888347648318405j), -0.0625, (-0.0625-0.025888347648318433j), (-0.0625-0.062499999999999986j), (-0.062499999999999986-0.1508883476483184j)]
numpy_result: [(4.5+0j), (-0.5-1.2071067811865475j), (-0.5-0.5j), (-0.5-0.20710678118654757j), (-0.5+0j), (-0.5+0.20710678118654757j), (-0.5+0.5j), (-0.5+1.2071067811865475j)]
difference: [(-3.9375+0j), (0.4375+1.357995128834866j), (0.4375+0.5625j), (0.4375+0.23299512883486598j), (0.4375+0j), (0.4375-0.232995128834866j), (0.4375-0.5625j), (0.4375-1.357995128834866j)]

case3 [1, 4, 2, 9, 0, 0, 3, 8, 9, 1, 4, 0, 0, 0, 0, 0]
manual_result: [0.0400390625, (-0.01241287175508141-0.012921426103324331j), (0.012527760864009951-0.007072891296014926j), (-0.014348436570795205-0.007251592526879778j), (0.0009765625000000013+0.01171875j), (0.005628776230820083-0.005870524594874804j), (0.007003489135990047-0.0012135162960149274j), (-0.01011746790494347-0.011540358171319353j), -0.0029296875, (-0.010117467904943469+0.011540358171319355j), (0.007003489135990049+0.0012135162960149274j), (0.005628776230820081+0.005870524594874803j), (0.0009765624999999987-0.01171875j), (-0.014348436570795205+0.0072515925268797805j), (0.012527760864009953+0.007072891296014926j), (-0.012412871755081408+0.01292142610332433j)]
numpy_result: [(2.5625+0j), (-0.7944237923252102+0.8269712706127572j), (0.8017766952966369+0.45266504294495535j), (-0.9182999405308933+0.46410192172030584j), (0.0625-0.75j), (0.3602416787724855+0.37571357407198736j), (0.44822330470336313+0.07766504294495535j), (-0.647517945916382+0.7385829229644387j), (-0.1875+0j), (-0.647517945916382-0.7385829229644387j), (0.44822330470336313-0.07766504294495535j), (0.3602416787724855-0.37571357407198736j), (0.0625+0.75j), (-0.9182999405308933-0.46410192172030584j), (0.8017766952966369-0.45266504294495535j), (-0.7944237923252102-0.8269712706127572j)]
difference: [(-2.5224609375+0j), (0.7820109205701288-0.8398926967160816j), (-0.7892489344326269-0.45973793424097026j), (0.903951503960098-0.47135351424718563j), (-0.0615234375+0.76171875j), (-0.3546129025416654-0.38158409866686216j), (-0.4412198155673731-0.07887855924097029j), (0.6374004780114385-0.7501232811357581j), (0.1845703125+0j), (0.6374004780114385+0.7501232811357581j), (-0.4412198155673731+0.07887855924097029j), (-0.3546129025416654+0.38158409866686216j), (-0.0615234375-0.76171875j), (0.903951503960098+0.47135351424718563j), (-0.7892489344326269+0.45973793424097026j), (0.7820109205701288+0.8398926967160816j)]

@pjs,感谢您的提醒,FFT 要求 len(data) 是 2 的幂。

正如评论中指出的那样,您在 omg_n 的计算中使用了正号。 DFT 有不同的定义,所以它本身并没有错。但是,如果您将结果与使用负号的实现进行比较,这自然会导致差异,例如 numpy.fft.fft。将您的实现调整为也使用负号将涵盖所有正向变换情况(只留下大约 ~10-16)的小舍入误差。

对于逆变换情况,您的实施最终会在每个阶段将结果缩放 1/n,而不仅仅是最后阶段。要更正此问题,只需从递归中删除缩放比例,并仅在最后阶段进行归一化:

def FFTrecursion(a: List, flag: bool) -> List:
    """Recursion of the FFT implementation"""

    n = len(a)
    if n == 1:
        return a

    # complex root
    omg_n = exp(-2 * pi * 1j / n)
    if flag:
        # IFFT
        omg_n = 1 / omg_n
    omg = 1

    # split a into 2 part
    a0 = a[::2]  # even
    a1 = a[1::2]  # odd

    # corresponding y
    y0 = FFTrecursion(a0, flag)
    y1 = FFTrecursion(a1, flag)

    # result y
    y = [0] * n
    for k in range(n // 2):
        y[k] = y0[k] + omg * y1[k]
        y[k + n // 2] = y0[k] - omg * y1[k]
        omg = omg * omg_n

    return y


def FFT(a: List, flag: bool) -> List:
    """realize DFT using FFT"""

    y = FFTrecursion(a, flag)

    # IFFT final scaling
    if flag:
        n = len(a)
        y = [i / n for i in y]
    return y