Using random_split() in Python to split the training set into training and validation sets

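import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True, download=True)
test_dataset  = torchvision.datasets.FashionMNIST(data_dir, train=False, download=True)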

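With the two lines above I load the FashionMNIST dataset. I then use the following lines to convert the images to tensors and wrap both datasets in DataLoaders: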

tr = transforms.Compose([transforms.ToTensor()])
train_dataset.transform = tr
test_dataset.transform = tr
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)

Then I iterate over the data with a for loop like the one below to train the model in PyTorch:

for images, labels in train_dataloader:
    ...  # training step (forward pass, loss, backward pass, optimizer step)

However, after I split the training data into two parts with random_split, as shown below, the same for loop raises an error:
train_dataset, val_dataset = random_split(train_dataset, (50000, 10000))

train_dataset.transform = tr
test_dataset.transform = tr
val_dataset.transform = tr

train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
validation_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)
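For context, random_split() does not copy the data. It returns torch.utils.data.Subset objects that hold a reference to the original dataset plus a list of indices, and the lengths must add up to the dataset's size (60,000 for the FashionMNIST training set). The sketch below uses a hypothetical toy `RangeDataset` in place of FashionMNIST so it runs without downloading anything, and passes an optional `generator` so the split is reproducible across runs.

```python
import torch
from torch.utils.data import Dataset, Subset, random_split


class RangeDataset(Dataset):
    """Hypothetical toy stand-in for the 60,000-image FashionMNIST training set."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx  # the sample is just its own index


full = RangeDataset(60000)
gen = torch.Generator().manual_seed(42)  # fixed seed: same split on every run
train_part, val_part = random_split(full, [50000, 10000], generator=gen)

# Both parts are Subset views over `full`, not new FashionMNIST-like datasets.
print(isinstance(train_part, Subset), isinstance(val_part, Subset))
print(len(train_part), len(val_part))
# The two index lists never overlap.
print(set(train_part.indices).isdisjoint(val_part.indices))
```

Because each part is a Subset rather than a FashionMNIST object, attributes such as `.transform` that FashionMNIST uses in its `__getitem__` are not consulted when you index the Subset.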

The error is:

default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
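The error means the DataLoader received untransformed PIL images. The sketch below reproduces the cause without torchvision: `ToyImageDataset` is a hypothetical dataset that, like FashionMNIST, applies `self.transform` inside `__getitem__`, and a string plays the role of a PIL image. Setting `.transform` on the Subset has no effect. Setting it on the wrapped dataset (`subset.dataset.transform`) does work, but it changes both subsets at once, since they share the same underlying object.

```python
from torch.utils.data import Dataset, random_split


class ToyImageDataset(Dataset):
    """Hypothetical stand-in that, like FashionMNIST, applies self.transform per sample."""

    def __init__(self, n, transform=None):
        self.n = n
        self.transform = transform

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        sample = f"raw-{idx}"  # plays the role of an untransformed PIL image
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


def fake_to_tensor(s):
    """Plays the role of transforms.ToTensor() in this toy example."""
    return "tensor(" + s + ")"


base = ToyImageDataset(10)
train_part, val_part = random_split(base, [8, 2])

train_part.transform = fake_to_tensor  # only adds a new, unused attribute to the Subset
print(train_part[0])                   # still an untransformed "raw-<idx>" string
print(base.transform)                  # None: the wrapped dataset was never changed

train_part.dataset.transform = fake_to_tensor  # set it on the wrapped dataset instead
print(train_part[0])                           # now "tensor(raw-<idx>)"
print(val_part[0])                             # also transformed: both subsets share `base`
```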

How can I fix this?

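You should pass the transform directly to the FashionMNIST dataset constructor. random_split() returns torch.utils.data.Subset objects, and assigning .transform to a Subset only creates a new attribute on the Subset. Subset.__getitem__ simply returns the item at the corresponding index of the wrapped FashionMNIST dataset, which was never given a transform, so the DataLoader still receives PIL images: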

train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True, download=True, transform=tr)
test_dataset  = torchvision.datasets.FashionMNIST(data_dir, train=False, download=True, transform=tr)
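Passing the transform to the constructor means the training and validation subsets share one transform. If you later want different transforms per split (for example, augmentation on the training part only), one common pattern is a small wrapper that applies its own transform on top of a Subset. `TransformedSubset` below is a hypothetical helper, not part of PyTorch, and the toy `PairDataset` and `to_tensor` lambda stand in for FashionMNIST (created without a transform) and `transforms.ToTensor()` so the sketch runs offline.

```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split


class TransformedSubset(Dataset):
    """Hypothetical helper: wraps a Subset and applies its own transform to the input."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, y = self.subset[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


class PairDataset(Dataset):
    """Toy stand-in for FashionMNIST without a transform: returns (raw_input, label)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return float(idx), idx % 2


base = PairDataset(10)
train_part, val_part = random_split(base, [8, 2], generator=torch.Generator().manual_seed(0))

to_tensor = lambda x: torch.tensor([x])  # stands in for transforms.ToTensor()
train_ds = TransformedSubset(train_part, to_tensor)  # could use an augmenting transform here
val_ds = TransformedSubset(val_part, to_tensor)

loader = DataLoader(train_ds, batch_size=4, shuffle=False)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([4, 1]) torch.Size([4])
```

With the real data, you would build the FashionMNIST training set without `transform=`, split it with random_split(..., (50000, 10000)), and wrap each part with its own transform before creating the DataLoaders.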