使用 Pytorch 在频域中对图像进行上采样

Upsampling images in frequency domain using Pytorch

我正在尝试使用 Pytorch 在频域中对 RGB 图像进行上采样。我使用 this article 作为灰度图像的参考。由于 Pytorch 单独处理通道,我认为色彩空间在这里无关紧要。

本文概述的基本步骤是:

  1. 对图像执行 FFT。

  2. 用零填充 FFT。

  3. 执行反 FFT。

我为此编写了以下代码:

import torch
import cv2
import numpy as np


img = src = cv2.imread('orig.png')
torch_img = torch.from_numpy(img).to(torch.float32).permute(2, 0, 1) / 255.
fft = torch.fft.fft2(torch_img, norm="forward")
fr = fft.real
fi = fft.imag
fr = F.pad(fr, (fft.shape[-1]//2, fft.shape[-1]//2, fft.shape[-2]//2, fft.shape[-2]//2), mode='constant', value=0)
fi = F.pad(fi, (fft.shape[-1]//2, fft.shape[-1]//2, fft.shape[-2]//2, fft.shape[-2]//2), mode='constant', value=0)

fft_hires = torch.complex(fr, fi)
inv = torch.fft.ifft2(fft_hires, norm="forward").real

print(inv.max(), inv.min())
img = (inv.permute(1, 2, 0).detach()).clamp(0, 1)
img = (255 * img).numpy().astype(np.uint8)
cv2.imwrite('hires.png', img)

原图:

放大图像:

另一个值得注意的有趣的事情是执行IFFT后图像像素的最大值和最小值:它们分别是2.2729-1.8376。理想情况下,它们应该是 1.0 和 0.0。

谁能解释一下这里出了什么问题?

DFT 通常的惯例是将第一个样本视为 0Hz 分量。但是你需要在中心有 0Hz 分量才能使填充有意义。大多数 FFT 工具都提供移位功能来循环移位您的结果,以便 0Hz 分量位于中心。在 pytorch 中,您需要在 FFT 之后执行 torch.fft.fftshift 并在进行逆 FFT 之前执行 torch.fft.ifftshift 以将 0Hz 分量放回左上角。

import torch
import torch.nn.functional as F
import cv2
import numpy as np


img = src = cv2.imread('orig.png')
torch_img = torch.from_numpy(img).to(torch.float32).permute(2, 0, 1) / 255.
# note the fftshift
fft = torch.fft.fftshift(torch.fft.fft2(torch_img, norm="forward"))

fr = fft.real
fi = fft.imag
fr = F.pad(fr, (fft.shape[-1]//2, fft.shape[-1]//2, fft.shape[-2]//2, fft.shape[-2]//2), mode='constant', value=0)
fi = F.pad(fi, (fft.shape[-1]//2, fft.shape[-1]//2, fft.shape[-2]//2, fft.shape[-2]//2), mode='constant', value=0)

# note the ifftshift
fft_hires = torch.fft.ifftshift(torch.complex(fr, fi))
inv = torch.fft.ifft2(fft_hires, norm="forward").real

print(inv.max(), inv.min())
img = (inv.permute(1, 2, 0).detach()).clamp(0, 1)
img = (255 * img).numpy().astype(np.uint8)
cv2.imwrite('hires.png', img)

产生以下 hires.png