错误的结果袖带 3D 到位

Wrong results cufft 3D in-place

我写这篇文章是因为我在原位袖带 3D 转换方面遇到了问题,而我对异地版本没有任何问题。我尝试遵循 Robert Crovella 的回答 here,但在进行 FFT+IFT 时我没有获得正确的结果。 这是我的代码:

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <complex.h>
#include <cuComplex.h>
#include <cufft.h>

// Main function
int main(int argc, char **argv){
    int N = 4; 
    double *in = NULL, *d_in = NULL;
    cuDoubleComplex *out = NULL, *d_out = NULL;
    cufftHandle plan_r2c, plan_c2r;

    unsigned int out_mem_size = sizeof(cuDoubleComplex) * N*N*(N/2 + 1);
    unsigned int in_mem_size = out_mem_size;

    in  = (double *) malloc (in_mem_size);
    out  = (cuDoubleComplex *)in;

    cudaMalloc((void **)&d_in, in_mem_size);
    d_out = (cuDoubleComplex *)d_in;

    cufftPlan3d(&plan_r2c, N, N, N, CUFFT_D2Z);
    cufftPlan3d(&plan_c2r, N, N, N, CUFFT_Z2D);

    memset(in, 0, in_mem_size);
    unsigned int idx;
    for (int z = 0; z < N; z++){
        for (int y = 0; y < N; y++){
            for (int x = 0; x < N; x++){
                idx = z + N * ( y + x * N);
                in[idx] = idx;
            }
        }
    }
    printf("\nStart: \n");
    for (int z = 0; z < N; z++){
        printf("plane = %d ----------------------------\n", z);
        for (int x = 0; x < N; x++){
            for (int y = 0; y < N; y++){
                idx = z + N * ( y + x * N);
                printf("%.3f \t", in[idx]);
            }
            printf("\n");
        }
    }
    cudaMemcpy(d_in, in, in_mem_size, cudaMemcpyHostToDevice);

    cufftExecD2Z(plan_r2c, (cufftDoubleReal *)d_in, (cufftDoubleComplex *)d_out);
    cufftExecZ2D(plan_c2r, (cufftDoubleComplex *)d_out, (cufftDoubleReal *)d_in);

    memset(in, 0, in_mem_size);
    CU_ERR_CHECK( cudaMemcpy(in, d_in, in_mem_size, cudaMemcpyDeviceToHost) );

    printf("\nAfter FFT+IFT: \n");
    for (int z = 0; z < N; z++){
        printf("plane = %d ----------------------------\n", z);
        for (int x = 0; x < N; x++){
            for (int y = 0; y < N; y++){
                idx = z + N * ( y + x * N);
                // Normalisation
                in[idx] /= (N*N*N);
                printf("%.3f \t", in[idx]);
            }
            printf("\n");
        }
    }

    return 0;
}

程序输出如下数据:

起始文件

平面=0----------------------------

0.000 4.000 8.000 12.000
16.000 20.000 24.000 28.000
32.000 36.000 40.000 44.000
48.000 52.000 56.000 60.000

平面=1----------------------------

1.000 5.000 9.000 13.000
17.000 21.000 25.000 29.000
33.000 37.000 41.000 45.000
49.000 53.000 57.000 61.000

平面=2----------------------------

2.000 6.000 10.000 14.000
18.000 22.000 26.000 30.000
34.000 38.000 42.000 46.000
50.000 54.000 58.000 62.000

平面= 3 --------------------------

3.000 7.000 11.000 15.000
19.000 23.000 27.000 31.000
35.000 39.000 43.000 47.000
51.000 55.000 59.000 63.000

FFT+IFT后

平面=0----------------------------

-0.000 -0.344 8.000 12.000
-0.031 20.000 24.000 -0.031
32.000 36.000 0.031 44.000
48.000 -0.094 56.000 60.000

平面=1----------------------------

1.000 -0.000 9.000 13.000
-0.000 21.000 25.000 0.125
33.000 37.000 0.000 45.000
49.000 0.000 57.000 61.000

平面=2----------------------------

2.000 6.000 -0.000 14.000
18.000 0.000 26.000 30.000
0.000 38.000 42.000 -0.000
50.000 54.000 -0.000 62.000

平面= 3 --------------------------

3.000 7.000 0.031 15.000
19.000 -0.031 27.000 31.000
-0.031 39.000 43.000 0.031
51.000 55.000 0.031 63.000

我什至试过这样填充数据:

// With padding
    unsigned int idx;
    for (int x = 0; x < N; x++){
        for (int y = 0; y < N; y++){
            for (int z = 0; z < 2*(N/2+1); z++){
                idx = z + N * ( y + x * N);
                if (z < 4) in[idx] = idx;
                else in[idx] = 0;
            }
        }
    }

我做错了什么?

如您所知,如果您使用默认的 CUFFT_COMPATIBILITY_FFTW_PADDING 兼容模式,则需要填充。为了使您的代码正常工作,您可以使用 cufftSetCompatibilityMode() 来设置 CUFFT_COMPATIBILITY_NATIVE。但是,此模式在当前版本的 CUDA 中被标记为已弃用。

所以我推荐使用默认的兼容模式,使用padding。您尝试实施填充是错误的。计算 3 维 x、y、z 的线性索引的公式是 idx = z + Nz*(y + Ny*x),其中 z 是最快的 运行 索引。包括填充在内的 z 维度的大小 NzNz = (N/2+1)*2。那么,数组的正确初始化是:

unsigned int idx;
for (int z = 0; z < N; z++){
    for (int y = 0; y < N; y++){
        for (int x = 0; x < N; x++){
            idx = z + (N/2+1)*2 * ( y + x * N);
            in[idx] = idx;
        }
    }
}

相应的打印循环。