Pytorch DataLoader shuffle=False?

Pytorch DataLoader shuffle=False?

我使用 Pytorch DataLoader 创建了我的“批处理数据”加载器,但我遇到了一些问题。

作为pytorch DataLoader Shuffer的定义。

shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False)

数据将在每个纪元后重新洗牌。 但是,虽然我将 shuffle 设置为 False,但我可能还会在我期望的同一时期的每次迭代中得到 完全不同的批处理

testData = torchvision.datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

CurrentFoldTestDataLoader = data.DataLoader(testData, batch_size=32, shuffle=False)
for i in range(1000):
    test_features, test_labels = next(iter(CurrentFoldTestDataLoader))
    print(i,test_labels)

在这里,我在每次迭代中都得到了相同的批次。

0 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
1 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
2 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
3 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
4 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
5 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
6 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
7 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
8 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
9 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])
10 tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8])

这是为什么?我对shuffle定义的理解不准确吗?

您的代码的问题是您 re-instantiating for 循环中的每个步骤都是同一个迭代器。使用 shuffle=False 迭代器生成相同的第一批图像。尝试在循环外实例化加载器:

loader = data.DataLoader(testData, batch_size=32, shuffle=False)

for i, data in enumerate(loader):
    test_features, test_labels = data
    print(i, test_labels)