如何从 .pt 文件创建 Pytorch 数据集?

How to create a Pytorch Dataset from .pt files?

我已将 MNIST 图像转换为 .pt 文件并保存在 Google 驱动器的文件夹中。我正在 Colab 中编写我的 Pytorch 代码。

我想使用这些文件,并创建一个将这些图像存储为张量的数据集。我该怎么做?

训练期间转换图像花费的时间太长。因此,将它们进行转换并将它们全部保存为 .pt 文件。我只想将它们作为数据集加载回来并在我的模型中使用它们。

您采用的保存图像的方法确实是个好主意。在这种情况下,您可以简单地编写自己的数据集 class 来加载图像。

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler

class ReaderDataset(Dataset):
    def __init__(self, filename):
        # load the images from file

    def __len__(self):
        # return total dataset size

    def __getitem__(self, index):
        # write your code to return each batch element

然后就可以按如下方式创建Dataloader了。

train_dataset = ReaderDataset(filepath)
train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.data_workers,
    collate_fn=batchify,
    pin_memory=args.cuda,
    drop_last=args.parallel
)
# args is a dictionary containing parameters
# batchify is a custom function that prepares each mini-batch