fft 算法产生不精确的结果
fft algorithm yields imprecise results
我正在尝试基于 dft(离散傅立叶变换)矩阵实现 fft(快速傅立叶变换)factorization.In 以下代码,包括 fft 和直接方法(即:直接乘以 dft 矩阵与 v) 的实施是为了测试我实施 fft 的有效性。
import numpy as n
import cmath, math
import matplotlib.pyplot as plt
v=n.array([1,-1,2,-3])
w=v
N=len(v)
t=[0]*N
M=n.zeros((N,N),dtype=complex)
z=n.exp(2j*math.pi/N)
for a in range(N):
for b in range(N):
M[a][b]=n.exp(2j*math.pi*a*b/N)
print (n.dot(v,M))
plt.plot(n.dot(v,M))
def f(x):
x=n.concatenate([x[::2],x[1::2]])
return x
while (w!=f(v)).any():
v=f(v)
print(v)
a=2
while a<=N:
for k in range(N/a):
for y in range(a/2):
t[y]=v[a*k+y]
for i in range(a/2):
v[a*k+i]+=v[a*k+i+a/2]*(z**i)
v[a*k+i+a/2]=t[i]-v[a*k+i+a/2]*(z**i)
a*=2
print(v)
plt.plot(v)
plt.show()
我用很多 v 值尝试过,有时这两种方法的输出产生完全相同的结果,但有时它们彼此接近但不完全相同。经过几次测试后,他们还没有远离彼此,每个测试都有不同的 v 值。
有什么我遗漏的导致代码不精确的地方吗?
编辑:
请注意,该代码是为 Python 2 设计的(因为隐式整数除法)。
看来问题不在算法上,而是在v的声明上(感谢@kazemakase)。尝试
v=n.array([1,-1,2,-3], dtype=complex)
相反。至少对我来说,曲线然后出现在彼此之上:
编辑
这真是一段旅程。我无法弄清楚您的代码有什么问题,但看起来有几个错误,包括 dft 和 fft。最后,我根据 [this document] (http://www.cs.cmu.edu/afs/andrew/scs/cs/15-463/2001/pub/www/notes/fourier/fourier.pdf) (pages 6 -- 9 hold all the information you need). Maybe you can go through the algorithm and figure out where your problems lie. The algorithm for the bit reversal can be found in this answer (or alternatively in this one ) 编写了自己的 fft 版本。我测试了不同长度的线性向量的代码——如果您发现任何错误,请告诉我。
import numpy as np
import cmath
def bit_reverse(x,n):
"""
Reverse the last n bits of x
"""
##from
##formstr = '{{:0{}b}}'.format(n)
##return int(formstr.format(x)[::-1],2)
##from
return sum(1<<(n-1-i) for i in range(n) if x>>i&1)
def permute_vector(v):
"""
Permute vector v such that the indices of the result
correspond to the bit-reversed indices of the original.
Returns the permuted input vector and the number of bits used.
"""
##check that len(v) == 2**n
##and at the same time find permutation length:
L = len(v)
comp = 1
bits = 0
while comp<L:
comp *= 2
bits += 1
if comp != L:
raise ValueError('permute_vector: wrong length of v -- must be 2**n')
rindices = [bit_reverse(i,bits)for i in range(L)]
return v[rindices],bits
def dft(v):
N = v.shape[0]
a,b = np.meshgrid(
np.linspace(0,N-1,N,dtype=np.complex128),
np.linspace(0,N-1,N,dtype=np.complex128),
)
M = np.exp((-2j*np.pi*a*b)/N)
return np.dot(M,v)
def fft(v):
w,bits = permute_vector(v)
N = w.shape[0]
z=np.exp(np.array(-2j,dtype=np.complex128)*np.pi/N)
##starting fft
for i in range(bits):
dist = 2**i ##distance between 'exchange pairs'
group = dist*2 ##size of sub-groups
for start in range(0,N,group):
for offset in range(group//2):
pos1 = start+offset
pos2 = pos1+dist
alpha1 = z**((pos1*N//group)%N)
alpha2 = z**((pos2*N//group)%N)
w[pos1],w[pos2] = w[pos1]+alpha1*w[pos2],w[pos1]+alpha2*w[pos2]
return w
if __name__ == '__main__':
#test the fft
for n in [2**i for i in range(1,5)]:
print('-'*25+'n={}'.format(n)+'-'*25)
v = np.linspace(0,n-1,n, dtype=np.complex128)
print('v = ')
print(v)
print('fft(v) = ')
print(fft(v))
print('dft(v) = ')
print(dft(v))
print('relative error:')
print(abs(fft(v)-dft(v))/abs(dft(v)))
这给出了以下输出:
-------------------------n=2-------------------------
v =
[ 0.+0.j 1.+0.j]
fft(v) =
[ 1. +0.00000000e+00j -1. -1.22464680e-16j]
dft(v) =
[ 1. +0.00000000e+00j -1. -1.22464680e-16j]
relative error:
[ 0. 0.]
-------------------------n=4-------------------------
v =
[ 0.+0.j 1.+0.j 2.+0.j 3.+0.j]
fft(v) =
[ 6. +0.00000000e+00j -2. +2.00000000e+00j -2. -4.89858720e-16j
-2. -2.00000000e+00j]
dft(v) =
[ 6. +0.00000000e+00j -2. +2.00000000e+00j -2. -7.34788079e-16j
-2. -2.00000000e+00j]
relative error:
[ 0.00000000e+00 0.00000000e+00 1.22464680e-16 3.51083347e-16]
-------------------------n=8-------------------------
v =
[ 0.+0.j 1.+0.j 2.+0.j 3.+0.j 4.+0.j 5.+0.j 6.+0.j 7.+0.j]
fft(v) =
[ 28. +0.00000000e+00j -4. +9.65685425e+00j -4. +4.00000000e+00j
-4. +1.65685425e+00j -4. -7.10542736e-15j -4. -1.65685425e+00j
-4. -4.00000000e+00j -4. -9.65685425e+00j]
dft(v) =
[ 28. +0.00000000e+00j -4. +9.65685425e+00j -4. +4.00000000e+00j
-4. +1.65685425e+00j -4. -3.42901104e-15j -4. -1.65685425e+00j
-4. -4.00000000e+00j -4. -9.65685425e+00j]
relative error:
[ 0.00000000e+00 6.79782332e-16 7.40611132e-16 1.85764404e-15
9.19104080e-16 3.48892999e-15 3.92837008e-15 1.35490975e-15]
-------------------------n=16-------------------------
v =
[ 0.+0.j 1.+0.j 2.+0.j 3.+0.j 4.+0.j 5.+0.j 6.+0.j 7.+0.j
8.+0.j 9.+0.j 10.+0.j 11.+0.j 12.+0.j 13.+0.j 14.+0.j 15.+0.j]
fft(v) =
[ 120. +0.00000000e+00j -8. +4.02187159e+01j -8. +1.93137085e+01j
-8. +1.19728461e+01j -8. +8.00000000e+00j -8. +5.34542910e+00j
-8. +3.31370850e+00j -8. +1.59129894e+00j -8. +2.84217094e-14j
-8. -1.59129894e+00j -8. -3.31370850e+00j -8. -5.34542910e+00j
-8. -8.00000000e+00j -8. -1.19728461e+01j -8. -1.93137085e+01j
-8. -4.02187159e+01j]
dft(v) =
[ 120. +0.00000000e+00j -8. +4.02187159e+01j -8. +1.93137085e+01j
-8. +1.19728461e+01j -8. +8.00000000e+00j -8. +5.34542910e+00j
-8. +3.31370850e+00j -8. +1.59129894e+00j -8. -6.08810394e-14j
-8. -1.59129894e+00j -8. -3.31370850e+00j -8. -5.34542910e+00j
-8. -8.00000000e+00j -8. -1.19728461e+01j -8. -1.93137085e+01j
-8. -4.02187159e+01j]
relative error:
[ 0.00000000e+00 1.09588741e-15 1.45449990e-15 6.36716793e-15
8.53211992e-15 9.06818284e-15 1.30922044e-14 5.40949529e-15
1.11628436e-14 1.23698141e-14 1.50430426e-14 3.02428869e-14
2.84810617e-14 1.16373983e-14 1.10680934e-14 3.92841628e-15]
这是一个很好的挑战 -- 我学到了很多东西!您可以在线验证代码的结果,例如here.
我正在尝试基于 dft(离散傅立叶变换)矩阵实现 fft(快速傅立叶变换)factorization.In 以下代码,包括 fft 和直接方法(即:直接乘以 dft 矩阵与 v) 的实施是为了测试我实施 fft 的有效性。
import numpy as n
import cmath, math
import matplotlib.pyplot as plt
v=n.array([1,-1,2,-3])
w=v
N=len(v)
t=[0]*N
M=n.zeros((N,N),dtype=complex)
z=n.exp(2j*math.pi/N)
for a in range(N):
for b in range(N):
M[a][b]=n.exp(2j*math.pi*a*b/N)
print (n.dot(v,M))
plt.plot(n.dot(v,M))
def f(x):
x=n.concatenate([x[::2],x[1::2]])
return x
while (w!=f(v)).any():
v=f(v)
print(v)
a=2
while a<=N:
for k in range(N/a):
for y in range(a/2):
t[y]=v[a*k+y]
for i in range(a/2):
v[a*k+i]+=v[a*k+i+a/2]*(z**i)
v[a*k+i+a/2]=t[i]-v[a*k+i+a/2]*(z**i)
a*=2
print(v)
plt.plot(v)
plt.show()
我用很多 v 值尝试过,有时这两种方法的输出产生完全相同的结果,但有时它们彼此接近但不完全相同。经过几次测试后,他们还没有远离彼此,每个测试都有不同的 v 值。
有什么我遗漏的导致代码不精确的地方吗?
编辑: 请注意,该代码是为 Python 2 设计的(因为隐式整数除法)。
看来问题不在算法上,而是在v的声明上(感谢@kazemakase)。尝试
v=n.array([1,-1,2,-3], dtype=complex)
相反。至少对我来说,曲线然后出现在彼此之上:
编辑
这真是一段旅程。我无法弄清楚您的代码有什么问题,但看起来有几个错误,包括 dft 和 fft。最后,我根据 [this document] (http://www.cs.cmu.edu/afs/andrew/scs/cs/15-463/2001/pub/www/notes/fourier/fourier.pdf) (pages 6 -- 9 hold all the information you need). Maybe you can go through the algorithm and figure out where your problems lie. The algorithm for the bit reversal can be found in this answer (or alternatively in this one ) 编写了自己的 fft 版本。我测试了不同长度的线性向量的代码——如果您发现任何错误,请告诉我。
import numpy as np
import cmath
def bit_reverse(x,n):
"""
Reverse the last n bits of x
"""
##from
##formstr = '{{:0{}b}}'.format(n)
##return int(formstr.format(x)[::-1],2)
##from
return sum(1<<(n-1-i) for i in range(n) if x>>i&1)
def permute_vector(v):
"""
Permute vector v such that the indices of the result
correspond to the bit-reversed indices of the original.
Returns the permuted input vector and the number of bits used.
"""
##check that len(v) == 2**n
##and at the same time find permutation length:
L = len(v)
comp = 1
bits = 0
while comp<L:
comp *= 2
bits += 1
if comp != L:
raise ValueError('permute_vector: wrong length of v -- must be 2**n')
rindices = [bit_reverse(i,bits)for i in range(L)]
return v[rindices],bits
def dft(v):
N = v.shape[0]
a,b = np.meshgrid(
np.linspace(0,N-1,N,dtype=np.complex128),
np.linspace(0,N-1,N,dtype=np.complex128),
)
M = np.exp((-2j*np.pi*a*b)/N)
return np.dot(M,v)
def fft(v):
w,bits = permute_vector(v)
N = w.shape[0]
z=np.exp(np.array(-2j,dtype=np.complex128)*np.pi/N)
##starting fft
for i in range(bits):
dist = 2**i ##distance between 'exchange pairs'
group = dist*2 ##size of sub-groups
for start in range(0,N,group):
for offset in range(group//2):
pos1 = start+offset
pos2 = pos1+dist
alpha1 = z**((pos1*N//group)%N)
alpha2 = z**((pos2*N//group)%N)
w[pos1],w[pos2] = w[pos1]+alpha1*w[pos2],w[pos1]+alpha2*w[pos2]
return w
if __name__ == '__main__':
#test the fft
for n in [2**i for i in range(1,5)]:
print('-'*25+'n={}'.format(n)+'-'*25)
v = np.linspace(0,n-1,n, dtype=np.complex128)
print('v = ')
print(v)
print('fft(v) = ')
print(fft(v))
print('dft(v) = ')
print(dft(v))
print('relative error:')
print(abs(fft(v)-dft(v))/abs(dft(v)))
这给出了以下输出:
-------------------------n=2-------------------------
v =
[ 0.+0.j 1.+0.j]
fft(v) =
[ 1. +0.00000000e+00j -1. -1.22464680e-16j]
dft(v) =
[ 1. +0.00000000e+00j -1. -1.22464680e-16j]
relative error:
[ 0. 0.]
-------------------------n=4-------------------------
v =
[ 0.+0.j 1.+0.j 2.+0.j 3.+0.j]
fft(v) =
[ 6. +0.00000000e+00j -2. +2.00000000e+00j -2. -4.89858720e-16j
-2. -2.00000000e+00j]
dft(v) =
[ 6. +0.00000000e+00j -2. +2.00000000e+00j -2. -7.34788079e-16j
-2. -2.00000000e+00j]
relative error:
[ 0.00000000e+00 0.00000000e+00 1.22464680e-16 3.51083347e-16]
-------------------------n=8-------------------------
v =
[ 0.+0.j 1.+0.j 2.+0.j 3.+0.j 4.+0.j 5.+0.j 6.+0.j 7.+0.j]
fft(v) =
[ 28. +0.00000000e+00j -4. +9.65685425e+00j -4. +4.00000000e+00j
-4. +1.65685425e+00j -4. -7.10542736e-15j -4. -1.65685425e+00j
-4. -4.00000000e+00j -4. -9.65685425e+00j]
dft(v) =
[ 28. +0.00000000e+00j -4. +9.65685425e+00j -4. +4.00000000e+00j
-4. +1.65685425e+00j -4. -3.42901104e-15j -4. -1.65685425e+00j
-4. -4.00000000e+00j -4. -9.65685425e+00j]
relative error:
[ 0.00000000e+00 6.79782332e-16 7.40611132e-16 1.85764404e-15
9.19104080e-16 3.48892999e-15 3.92837008e-15 1.35490975e-15]
-------------------------n=16-------------------------
v =
[ 0.+0.j 1.+0.j 2.+0.j 3.+0.j 4.+0.j 5.+0.j 6.+0.j 7.+0.j
8.+0.j 9.+0.j 10.+0.j 11.+0.j 12.+0.j 13.+0.j 14.+0.j 15.+0.j]
fft(v) =
[ 120. +0.00000000e+00j -8. +4.02187159e+01j -8. +1.93137085e+01j
-8. +1.19728461e+01j -8. +8.00000000e+00j -8. +5.34542910e+00j
-8. +3.31370850e+00j -8. +1.59129894e+00j -8. +2.84217094e-14j
-8. -1.59129894e+00j -8. -3.31370850e+00j -8. -5.34542910e+00j
-8. -8.00000000e+00j -8. -1.19728461e+01j -8. -1.93137085e+01j
-8. -4.02187159e+01j]
dft(v) =
[ 120. +0.00000000e+00j -8. +4.02187159e+01j -8. +1.93137085e+01j
-8. +1.19728461e+01j -8. +8.00000000e+00j -8. +5.34542910e+00j
-8. +3.31370850e+00j -8. +1.59129894e+00j -8. -6.08810394e-14j
-8. -1.59129894e+00j -8. -3.31370850e+00j -8. -5.34542910e+00j
-8. -8.00000000e+00j -8. -1.19728461e+01j -8. -1.93137085e+01j
-8. -4.02187159e+01j]
relative error:
[ 0.00000000e+00 1.09588741e-15 1.45449990e-15 6.36716793e-15
8.53211992e-15 9.06818284e-15 1.30922044e-14 5.40949529e-15
1.11628436e-14 1.23698141e-14 1.50430426e-14 3.02428869e-14
2.84810617e-14 1.16373983e-14 1.10680934e-14 3.92841628e-15]
这是一个很好的挑战 -- 我学到了很多东西!您可以在线验证代码的结果,例如here.