torchvision.datasets.ImageFolder 给我一个 3x3 的图像网格而不是 1 个图像
torchvision.datasets.ImageFolder is giving me a 3x3 grid of images intead of 1 image
我不明白为什么它会在 3x3 网格中提供 9 张灰色图像,而不是只有一张彩色图像(原始图像不是灰色的,并且具有 RGB 通道)。我在这上面花了 5 个小时。感谢您的帮助。
这是我的代码
test_path = "asl_data/test/" #path to the folder
test_data = torchvision.datasets.ImageFolder(test_path, transform=torchvision.transforms.ToTensor())
def test32():
for x, y in test_data:
print(x.shape)
x = x.reshape(533,800,3)
plt.axis("off")
plt.imshow(x)
plt.show()
plt.axis("off")
plt.imshow(x[:176,:267,:])
break
test32()
经典。
你 reshape
而不是 permute
。
请参阅 this thread 了解两者之间的关键 区别。
修复:
x = x.permute((1, 2, 0))
plt.imshow(x)
一个简单的视觉示例:
x, y = test_data[0] # take one image
x.shape # torch.Size([3, 223, 320])
# see the difference
fig, ax = plt.subplots(1,2)
ax[0].imshow(x.numpy().reshape(223, 320, 3))
ax[0].set_title('Wrong reshape instead of permute')
ax[1].imshow(x.permute((1,2,0)))
ax[1].set_title('correctly permuting')
我不明白为什么它会在 3x3 网格中提供 9 张灰色图像,而不是只有一张彩色图像(原始图像不是灰色的,并且具有 RGB 通道)。我在这上面花了 5 个小时。感谢您的帮助。
这是我的代码
test_path = "asl_data/test/" #path to the folder
test_data = torchvision.datasets.ImageFolder(test_path, transform=torchvision.transforms.ToTensor())
def test32():
for x, y in test_data:
print(x.shape)
x = x.reshape(533,800,3)
plt.axis("off")
plt.imshow(x)
plt.show()
plt.axis("off")
plt.imshow(x[:176,:267,:])
break
test32()
经典。
你 reshape
而不是 permute
。
请参阅 this thread 了解两者之间的关键 区别。
修复:
x = x.permute((1, 2, 0))
plt.imshow(x)
一个简单的视觉示例:
x, y = test_data[0] # take one image
x.shape # torch.Size([3, 223, 320])
# see the difference
fig, ax = plt.subplots(1,2)
ax[0].imshow(x.numpy().reshape(223, 320, 3))
ax[0].set_title('Wrong reshape instead of permute')
ax[1].imshow(x.permute((1,2,0)))
ax[1].set_title('correctly permuting')