Display result of convolution in PyTorch

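PyTorch newbie here. I wrote a script (code below) that does the following: it loads an image, performs a 2D convolution on it, and then displays the output next to the input.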

At the moment the output image I get looks wrong. How do I plot the feature map correctly?

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import imageio
import sys

A = imageio.imread('LiT.png')
# Define how the convolution operation works
conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)

image_d = torch.FloatTensor(np.asarray(A.reshape(1, 3, A.shape[0] , A.shape[1])))
fc = conv2(image_d)
fc1 = fc.permute(0, 2, 3, 1).reshape([516, 780, 3])

plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(A)
plt.subplot(1,2,2)
plt.imshow(fc1.data.numpy())

plt.show()

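As far as I can tell, the problem is how you use reshape to rearrange the position of the channels in the image. reshape only reinterprets the data in its existing memory order and never moves values between axes, so an [H, W, C] image read as [C, H, W] mixes pixels from different channels. Use np.transpose or tensor.permute instead. Permuting with torch: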

image_d  = torch.FloatTensor(np.asarray(A)).unsqueeze(0).permute(0,3,1,2)

Or, if you prefer to do the permutation in numpy:

image_d = np.transpose(np.asarray(A), (2,0,1))
image_d = torch.FloatTensor(image_d).unsqueeze(0)
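Both routes should give the same [1, C, H, W] tensor. A quick sanity check, assuming a small random uint8 array stands in for the image loaded from LiT.png (the 4x5x3 shape is arbitrary):

```python
import numpy as np
import torch

# Hypothetical stand-in for imageio.imread('LiT.png'): an [H, W, C] uint8 image
A = np.random.randint(0, 256, size=(4, 5, 3), dtype=np.uint8)

# Route 1: permute in torch, [H, W, C] -> [1, H, W, C] -> [1, C, H, W]
via_torch = torch.FloatTensor(np.asarray(A)).unsqueeze(0).permute(0, 3, 1, 2)

# Route 2: transpose in numpy to [C, H, W], then add the batch dim in torch
via_numpy = torch.FloatTensor(np.transpose(np.asarray(A), (2, 0, 1))).unsqueeze(0)

print(via_torch.shape)                    # torch.Size([1, 3, 4, 5])
print(torch.equal(via_torch, via_numpy))  # True
```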

The problem with your code is this line:

image_d = torch.FloatTensor(np.asarray(A.reshape(1, 3, A.shape[0] , A.shape[1])))

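You can't simply reshape an image whose channels need to be transposed. As a tip for the future: if you get a striped result like yours, the cause is most likely a missing permutation/transposition or an incorrect reshape.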
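To see where the stripes come from, compare the two operations on a tiny hand-made 2x2 RGB image (the values are chosen only so that each pixel and channel is easy to recognise). reshape re-cuts the flat HWC buffer, so the first "channel" holds R, G and B values of neighbouring pixels; transpose moves every value into its own channel plane.

```python
import numpy as np

# Hypothetical 2x2 RGB image in [H, W, C] layout.
# Pixel k = 2*row + col has R = 10*k + 1, G = 10*k + 2, B = 10*k + 3
img = np.array([[[1, 2, 3], [11, 12, 13]],
                [[21, 22, 23], [31, 32, 33]]])

# reshape keeps the flat memory order 1, 2, 3, 11, 12, ... and just re-cuts it
wrong = img.reshape(3, 2, 2)
# transpose reorders the axes: [H, W, C] -> [C, H, W]
right = img.transpose(2, 0, 1)

print(wrong[0].tolist())  # [[1, 2], [3, 11]]   R, G, B of different pixels mixed
print(right[0].tolist())  # [[1, 11], [21, 31]] only red values, in spatial order
```

On a real image this interleaving repeats along every row, which is what shows up as stripes in the plot.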

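Apart from that, I also rescaled the output to [0, 1] so that it is displayed correctly (for float RGB data, plt.imshow expects values in [0, 1] and clips anything outside that range). Here is the working code: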

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import imageio

A = imageio.imread('LiT.png')
# Define how the convolution operation works
conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)

# from [H, W, C] to [C, H, W]
transposed_image = A.transpose((2, 0, 1))
# add batch dim
transposed_image = np.expand_dims(transposed_image, 0)

image_d = torch.FloatTensor(transposed_image)
fc = conv2(image_d)
fc1 = fc.permute(0, 2, 3, 1)[0]
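result = fc1.detach().numpy()
max_ = np.max(result)
min_ = np.min(result)
# min-max scale to [0, 1]: divide by the range (max - min), not by max alone,
# otherwise the result exceeds 1 whenever min_ is negative
result -= min_
result /= (max_ - min_)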

plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(A)
plt.subplot(1,2,2)
plt.imshow(result)

plt.show()
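The same steps can be wrapped into a small reusable function. This is only a sketch under a few assumptions: the input is an [H, W, C] array with at least 3 channels (an alpha channel, as in some PNGs, is dropped), a random image replaces LiT.png, and the helper names conv_feature_map and to_displayable are hypothetical. It uses torch.from_numpy(...).float() instead of the legacy torch.FloatTensor constructor and runs the convolution under torch.no_grad(), since no gradients are needed for plotting.

```python
import numpy as np
import torch
import torch.nn as nn


def to_displayable(x: np.ndarray) -> np.ndarray:
    """Min-max scale an array into [0, 1] so plt.imshow shows it correctly."""
    x = x.astype(np.float32)
    lo, hi = x.min(), x.max()
    if hi == lo:  # constant map: avoid dividing by zero
        return np.zeros_like(x)
    return (x - lo) / (hi - lo)  # divide by the range, not by the max


def conv_feature_map(img: np.ndarray, conv: nn.Conv2d) -> np.ndarray:
    """Apply conv to an [H, W, C] image; return an [H, W, C_out] array in [0, 1]."""
    # Assumption: keep only the first 3 channels (drops alpha from RGBA PNGs)
    rgb = img[..., :3]
    # [H, W, C] -> [C, H, W], then add the batch dim -> [1, C, H, W]
    x = torch.from_numpy(np.ascontiguousarray(rgb.transpose(2, 0, 1))).float().unsqueeze(0)
    with torch.no_grad():  # inference only, no autograd graph
        out = conv(x)
    # [1, C_out, H, W] -> [C_out, H, W] -> [H, W, C_out] for plt.imshow
    return to_displayable(out[0].permute(1, 2, 0).numpy())


# Hypothetical usage with a random image instead of LiT.png
torch.manual_seed(0)
image = np.random.randint(0, 256, size=(32, 48, 3), dtype=np.uint8)
conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
result = conv_feature_map(image, conv2)
print(result.shape)  # (32, 48, 3)
```

With padding=1 and stride=1 a 3x3 kernel keeps H and W unchanged, so result can be passed straight to plt.imshow next to the input image.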