Pytorch 的 dataloader shuffle 什么时候发生?

When does dataloader shuffle happen for Pytorch?

我已经多次使用 pytorch 数据加载器的 shuffle 选项。但是我想知道这种洗牌是什么时候发生的,是否是在迭代过程中动态执行的。以下面的代码为例:

namesDataset = NamesDataset()
namesTrainLoader = DataLoader(namesDataset, batch_size=16, shuffle=True)
for batch_data in namesTrainLoader:
    print(batch_data)

当我们定义"namesTrainLoader"时,是否意味着洗牌结束,接下来的迭代将基于固定的数据顺序?定义namesTrainLoader后for循环会不会有随机性?

我试图用一些特殊值替换 "batch_data" 的一半:

for batch_data in namesTrainLoader:
    batch_data[:8] = special_val
    pre = model(batch_data)

假设会有无限个 epoches,"model" 最终会看到 "namesTrainLoader" 中的所有数据吗?或者"namesTrainLoader"一半的数据居然丢给了"model"?

您可以查看 PyTorch 的实现 torch.utils.data.DataLoader here

如果您指定 shuffle=True,将使用 torch.utils.data.RandomSampler(否则为 SequentialSampler)。

创建 DataLoader 的实例时,不会打乱任何内容,它只会实例化对象的必要私有成员和其他类似设置。

当你在迭代期间发出特殊的 __iter__ 方法时,就像你的情况一样,返回一个名为 _SingleProcessDataLoader(self) 的特殊对象,它是一个数据生成器(可能是批处理,洗牌等,假设你不t 使用多处理)。

要找到所有与私有和助手相关的方法有一点困难,但它基本上做的是使用底层 sampler 获取索引,这些索引用于从 torch.utils.data.Dataset.

采样器 运行 直到耗尽并重复该过程(通常是一个时期)。

Will there be any randomness in the for loop after namesTrainLoader was defined?

在每个 cycle/epoch RandomSampler 开始打乱索引 ,所以是的,它会在每个纪元之前随机化(当 __iter__ 被调用并返回新的 _SingleProcessDataLoader(self)),这可以无限期地完成。

[...] will "model" eventually see all the data in "namesTrainLoader"?

是的,它很可能最终会看到所有数据点

改组发生在创建迭代器时。对于 for 循环,这发生在 for 循环开始之前。

您可以手动创建迭代器:

# Iterator gets created, the data has been shuffled at this point.
data_iterator = iter(namesTrainLoader)

默认情况下,数据加载器使用 torch.utils.data.RandomSampler if you set shuffle=True (without providing your own sampler). Its implementation is very straight forward and you can see where the data is shuffled when the iterator is created by looking at the RandomSampler.__iter__ 方法:

def __iter__(self):
    n = len(self.data_source)
    if self.replacement:
        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    return iter(torch.randperm(n).tolist())

return 语句是重要的部分,其中发生了改组。它只是创建索引的随机排列。

这意味着每次完全使用迭代器时您都会看到整个数据集,只是每次的顺序不同。因此没有数据丢失(不包括 drop_last=True 的情况)并且您的模型将在每个时期看到所有数据。