在张量中间执行tf.signal.fft2d

Perform tf.signal.fft2d in the middle of the tensor

我有一个 tf.Tensor,例如,形状 (31, 6, 6, 3)

我想在形状 6, 6 上执行 tf.signal.fft2d,换句话说,在中间。然而,描述说:

Computes the 2-dimensional discrete Fourier transform over the inner-most 2 dimensions of input

我可以用 for 循环来完成,但我担心它可能非常无效。有没有最快的方法?

结果当然必须具有相同的输出形状。

感谢 I implemented this solution using tf.transpose:

in_pad = tf.transpose(in_pad, perm=[0, 3, 1, 2])
out = tf.signal.fft2d(tf.cast(in_pad, tf.complex64))
out = tf.transpose(out, perm=[0, 2, 3, 1])