FFT 使用递归 python 函数

FFT using recursive python function

我正在尝试使用给定列表的以下 code for finding FFT

经过大量试验后,我发现此代码仅适用于具有 2^m2^m+1 元素的输入列表。

您能否说明为什么会这样,以及是否可以修改它以使用包含其他数量元素的输入列表。 (P.S。我有一个包含 16001 个元素的输入列表)

    from cmath import exp, pi

    def fft(x):
        N = len(x)
        if N <= 1: return x
        even = fft(x[0::2])
        odd =  fft(x[1::2])
        T= [exp(-2j*pi*k/N)*odd[k] for k in xrange(N/2)]
        return [even[k] + T[k] for k in xrange(N/2)] + \
        [even[k] - T[k] for k in xrange(N/2)]

    print( ' '.join("%5.3f" % abs(f) 
            for f in fft([1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0])) )

编辑 1 你能说出前面和下面函数定义的区别吗:

def fft(x):
    N = len(x)
    T = exp(-2*pi*1j/N)
    if N > 1:
        x = fft(x[::2]) + fft(x[1::2])
        for k in xrange(N/2):
            xk = x[k]
            x[k] = xk + T**k*x[k+N/2]
            x[k+N/2] = xk - T**k*x[k+N/2]
    return x

编辑 2:实际上这段代码(在编辑 1 下)确实有效(对于之前缩进和变量命名的错误感到抱歉)这就是为什么我想了解两者之间的区别。(这适用于 16001元素也是!)

算法

此算法仅适用于输入数据的二次幂。 看看 theory.

这是检查此先决条件的改进版本:

from __future__ import print_function

import sys

if sys.version_info.major < 3:
    range = xrange

from cmath import exp, pi

def fft(x):
    N = len(x)
    if N <= 1: 
        return x
    if N % 2 > 0:
        raise ValueError("size of x must be a power of 2")
    even = fft(x[::2])
    odd =  fft(x[1::2])
    r = range(N//2)
    T = [exp(-2j * pi * k / N) * odd[k] for k in r]
    [even[k] for k in r]
    res = ([even[k] + T[k] for k in r] +
           [even[k] - T[k] for k in r])
    return res

input_data = [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
print(' '.join("%5.3f" % abs(f) for f in fft(input_data)))

NumPy 附带了 FFT 函数的优化版本。 它适用于不同于 2 的幂的大小。 但它会慢很多,因为它需要应用不同的算法。

例如:

from numpy.fft import fft as np_fft
​
n = 1000
a = np.arange(n ** 2)
b = np.arange(n ** 2 - 10)

二次幂情况的运行时间:

%timeit np_fft(a)
10 loops, best of 3: 59.3 ms per loop

远低于以下情况:

%timeit np_fft(b)
1 loops, best of 3: 511 ms per loop

递归限制

Python 内置递归限制为 1000:

>>> import sys
>>> sys.getrecursionlimit()
1000

但是你可以增加递归限制:

sys.setrecursion(50000)

docs告诉你为什么:

getrecursionlimit()

Return the current value of the recursion limit, the maximum depth of the Python interpreter stack. This limit prevents infinite recursion from causing an overflow of the C stack and crashing Python. It can be set by setrecursionlimit().

setrecursionlimit()

Set the maximum depth of the Python interpreter stack to limit. This limit prevents infinite recursion from causing an overflow of the C stack and crashing Python.

The highest possible limit is platform-dependent. A user may need to set the limit higher when they have a program that requires deep recursion and a platform that supports a higher limit. This should be done with care, because a too-high limit can lead to a crash.

对编辑版本的回答

虽然这个版本:

from __future__ import print_function

import sys

if sys.version_info.major < 3:
    range = xrange

from cmath import exp, pi

def fft2(x):
    N = len(x)
    T = exp(-2*pi*1j/N)
    if N > 1:
        x = fft2(x[::2]) + fft2(x[1::2])
        for k in range(N//2):
            xk = x[k]
            x[k] = xk + T**k*x[k+N//2]
            x[k+N//2] = xk - T**k*x[k+N//2]
    return x

似乎适用于输入的二次方数:

import numpy as np
from numpy.fft import fft as np_fft

data = [1, 2, 3, 4]
np.allclose(fft2(data), np_fft(data))

True

对于不同数量的输入,它没有给出正确的结果。

data2 = [1, 2, 3, 4, 5]
np.allclose(fft2(data2), np_fft(data2))

False

它仍然基于输入的数量是2的幂的假设,即使它不会抛出异常。

在上面的代码中这一行:

T = [exp(-2j * pi * k / N) * odd[k] for k in r]

如果你仔细观察这部分: exp(-2j * pi * k / N) 会顺时针方向引导所有输出向量,对于正确的逆时针答案,将 exp(-2j * pi * k / N) 替换为 exp(2j * pi * k / N) 作为 omega = exp(2j * pi * k / N)