PyTorch DataLoader 洗牌
PyTorch DataLoader shuffle
我做了一个实验,但没有得到预期的结果。
对于第一部分,我使用
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=False, num_workers=0)
在训练我的模型之前,我将 trainloader.dataset.targets
保存到变量 a
,并将 trainloader.dataset.data
保存到变量 b
。然后,我使用 trainloader
训练模型。
训练完成后,我将trainloader.dataset.targets
保存到变量c
,将trainloader.dataset.data
保存到变量d
。最后,我检查了 a == c
和 b == d
,它们都给出了 True
,这是预期的,因为 DataLoader
的 shuffle 参数是 False
.
对于第二部分,我使用
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=0)
在训练我的模型之前,我将 trainloader.dataset.targets
保存到变量 e
,并将 trainloader.dataset.data
保存到变量 f
。然后,我使用 trainloader
训练模型。训练结束后,我将trainloader.dataset.targets
保存到变量g
,将trainloader.dataset.data
保存到变量h
。我希望 e == g
和 f == h
都是 False
,因为 shuffle=True
,但他们又给出了 True
。我在 DataLoader
class 的定义中遗漏了什么?
我相信直接存储在 trainloader.dataset.data 或 .target 中的数据不会被打乱,只有当 DataLoader 作为生成器或迭代器被调用时数据才会被打乱
你可以通过 next(iter(trainloader)) 不洗牌和洗牌几次来检查它,它们应该给出不同的结果
import torch
import torchvision
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
transform = transform)
dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
batch_size = 128,
shuffle = False,
num_workers = 10)
target = dataLoader.dataset.targets
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
transform = transform)
dataLoader_shuffled= torch.utils.data.DataLoader(MNIST_dataset,
batch_size = 128,
shuffle = True,
num_workers = 10)
target_shuffled = dataLoader_shuffled.dataset.targets
print(target == target_shuffled)
_, target = next(iter(dataLoader));
_, target_shuffled = next(iter(dataLoader_shuffled))
print(target == target_shuffled)
这将给出:
tensor([True, True, True, ..., True, True, True])
tensor([False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, True,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, False, False, False, False, False,
False, True, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, True, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, True, False, False, True, False,
False, False, False, False, False, False, False, False])
然而,data 和 target 中存储的数据和标签是一个固定列表,由于您试图直接访问它,因此它们不会被打乱。
我在使用数据集 class 加载数据时遇到了类似的问题。我停止使用数据集 class 加载数据,而是使用以下代码,它对我来说工作正常
X = torch.from_numpy(X)
y = torch.from_numpy(y)
train_data = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
其中 X 和 y 是来自 csv 文件的 numpy 数组。
我做了一个实验,但没有得到预期的结果。
对于第一部分,我使用
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=False, num_workers=0)
在训练我的模型之前,我将 trainloader.dataset.targets
保存到变量 a
,并将 trainloader.dataset.data
保存到变量 b
。然后,我使用 trainloader
训练模型。
训练完成后,我将trainloader.dataset.targets
保存到变量c
,将trainloader.dataset.data
保存到变量d
。最后,我检查了 a == c
和 b == d
,它们都给出了 True
,这是预期的,因为 DataLoader
的 shuffle 参数是 False
.
对于第二部分,我使用
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=0)
在训练我的模型之前,我将 trainloader.dataset.targets
保存到变量 e
,并将 trainloader.dataset.data
保存到变量 f
。然后,我使用 trainloader
训练模型。训练结束后,我将trainloader.dataset.targets
保存到变量g
,将trainloader.dataset.data
保存到变量h
。我希望 e == g
和 f == h
都是 False
,因为 shuffle=True
,但他们又给出了 True
。我在 DataLoader
class 的定义中遗漏了什么?
我相信直接存储在 trainloader.dataset.data 或 .target 中的数据不会被打乱,只有当 DataLoader 作为生成器或迭代器被调用时数据才会被打乱
你可以通过 next(iter(trainloader)) 不洗牌和洗牌几次来检查它,它们应该给出不同的结果
import torch
import torchvision
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
transform = transform)
dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
batch_size = 128,
shuffle = False,
num_workers = 10)
target = dataLoader.dataset.targets
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
transform = transform)
dataLoader_shuffled= torch.utils.data.DataLoader(MNIST_dataset,
batch_size = 128,
shuffle = True,
num_workers = 10)
target_shuffled = dataLoader_shuffled.dataset.targets
print(target == target_shuffled)
_, target = next(iter(dataLoader));
_, target_shuffled = next(iter(dataLoader_shuffled))
print(target == target_shuffled)
这将给出:
tensor([True, True, True, ..., True, True, True])
tensor([False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, True,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, False, False, False, False, False,
False, True, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, True, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, True, False, False, True, False,
False, False, False, False, False, False, False, False])
然而,data 和 target 中存储的数据和标签是一个固定列表,由于您试图直接访问它,因此它们不会被打乱。
我在使用数据集 class 加载数据时遇到了类似的问题。我停止使用数据集 class 加载数据,而是使用以下代码,它对我来说工作正常
X = torch.from_numpy(X)
y = torch.from_numpy(y)
train_data = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
其中 X 和 y 是来自 csv 文件的 numpy 数组。