pytorch 4d numpy array 在自定义数据集中应用 transfroms

pytorch 4d numpy array applying transfroms inside custom dataset

在我的自定义数据集中,我想将 transforms.Compose() 应用于 NumPy 数组。

我的图像采用 NumPy 数组格式,形状为 (num_samples, width, height, channels)

如何将以下转换应用到整个 numpy 数组?

img_transform = transforms.Compose([ transforms.Scale((224,224)), transforms.ToTensor(), transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32]) ])

我的尝试以多个错误结束,因为转换接受的是 PIL 图像而不是 4 维 NumPy 数组。

from torchvision import transforms
import numpy as np
import torch

img_transform = transforms.Compose([
        transforms.Scale((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
    ])

a = np.random.randint(0,256, (299,299,3))
print(a.shape)

img_transform(a)

所有 torchvision 变换都对单个图像进行操作,而不是批量图像,因此不能使用 4D 阵列。

作为 NumPy 数组给出的单个图像,就像在您的代码示例中一样,可以通过将它们转换为 PIL 图像来使用。您可以简单地将 transforms.ToPILImage 添加到转换管道的开头,因为它将张量或 NumPy 数组转换为 PIL 图像。

img_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
    ])

注:transforms.Scale is deprecated in favour of transforms.Resize.

在您的示例中,您使用了 np.random.randint,它默认使用 int64 类型,但图像必须是 uint8。 OpenCV 等库 return uint8 arrays when loading an image.

a = np.random.randint(0,256, (299,299,3), dtype=np.uint8)