如何在 pytorch 中自行洗牌?

How to shuffle the batches themselves in pytorch?

如何在打乱批次的同时保持每个批次中的序列不打乱?

受到问题 here 的启发。

[错误答案 - 使用上面的 ]

  1. 创建数据集
dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9] # Realistically use torch.utils.data.Dataset
  1. 创建一个非随机数据加载器
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
  1. 将数据加载器转换为 list 并使用 randomsample() 函数
import random
dataloader = random.sample(list(dataloader), len(dataloader))

可能有更好的方法可以使用自定义批处理采样器或其他东西来执行此操作,但对我来说太混乱了,所以上面的方法似乎很有效。

虽然这不是您问题的直接答案。我想解决您自己发布的答案的问题。在我看来,执行以下操作是一个非常的坏主意:

dataloader = random.sample(list(dataloader), len(dataloader))

这首先违背了创建数据集和数据加载器的全部目的。因为一旦你调用 list(dataloader) 你最终会将你的数据集编译成一个张量列表。换句话说,它将为数据集中的每个索引调用 __getitem__。数据加载器旨在逐批加载数据(或更多,具体取决于工作人员的数量),避免一次将整个数据集加载到内存中。

这在处理需要从文件系统加载图像的图像时更为重要。这很关键,我相信你根本不应该这样做。

看看这里,有一个虚拟数据集:

class DS(Dataset):
    def __getitem__(self, _):
        return torch.rand(100)

    def __len__(self):
        return 10000

dl = DataLoader(DS(), batch_size=16)
x = list(dl)

此处 x 将包含 10,000 个大小为 100 的张量,您的计算机可以处理这些张量。现在想象一下,如果有一个由 10,000 512x512 RGB 图像组成的数据集,您的内存就无法容纳那么多!

此外,我什至没有提到的是数据扩充。这只有在保留数据加载器(即生成器)时才有可能。因此,在使用 list(dataloader).

时,转换是在运行时对输入数据计算的,而不是在 编译时 (如果你愿意的话)

我建议您让 Dataset 为每个项目生成 unshuffled 序列,然后用 [=18] 从中生成 DataLoader =].这感觉比生成 DataLoader 只是为了编译它要自然得多。按预期使用您的数据集 class。它应该是构建每个序列(即数据点)的那个,或者如 @Prune 所说的“单个观察对象”。