PyTorch Datasets: Converting entire Dataset to NumPy
I am trying to convert the Torchvision MNIST train and test datasets into NumPy arrays, but I can't find documentation that actually describes the conversion.
My goal is to take the entire dataset and convert it into a single NumPy array, preferably without iterating through the entire dataset.
I have looked at , but it does not address my problem.
So my question is: using torch.utils.data.DataLoader, how would I convert the datasets (train/test) into two NumPy arrays such that all of the examples are present?
Note: I have left the batch size at the default of 1 for now; I could set it to 60,000 for train and 10,000 for test, but I would prefer not to use magic numbers of that sort.
Thank you.
If I understand you correctly, you want to get the whole train dataset of MNIST images (in total 60,000 images, each of size 1x28x28, where the 1 is the color channel) as a NumPy array of shape (60000, 1, 28, 28)?
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transform to normalized Tensors
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST('./MNIST/', train=True, transform=transform, download=True)
# test_dataset = datasets.MNIST('./MNIST/', train=False, transform=transform, download=True)

# Use the dataset length as the batch size, so a single batch holds every example
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
# test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

# The first (and only) batch is (images, labels); take the images as a NumPy array
train_dataset_array = next(iter(train_loader))[0].numpy()
# test_dataset_array = next(iter(test_loader))[0].numpy()
Here is the (truncated) result:
>>> train_dataset_array
array([[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],

       ...,

       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]]], dtype=float32)
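As a quick sanity check on the result (using the names from the snippet above):

>>> train_dataset_array.shape
(60000, 1, 28, 28)
>>> train_dataset_array.dtype
dtype('float32')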
Edit: You can also get the labels via next(iter(train_loader))[1].numpy(). Alternatively, you can use train_dataset.data.numpy() and train_dataset.targets.numpy(), but note that the data will then not be transformed by transform as it is when using the data loader.
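For illustration, here is a minimal sketch comparing the two routes, assuming the snippet above has already run; the manual normalization is an assumption that simply replays ToTensor's scaling to [0, 1] followed by Normalize((0.1307,), (0.3081,)) on the raw uint8 data:

import numpy as np

# Route 1: through the DataLoader -- images arrive already transformed
images, labels = next(iter(train_loader))
train_images = images.numpy()               # shape (60000, 1, 28, 28), float32
train_labels = labels.numpy()               # shape (60000,)

# Route 2: the raw tensors stored on the dataset object -- no transform applied
raw_images = train_dataset.data.numpy()     # shape (60000, 28, 28), dtype uint8
raw_labels = train_dataset.targets.numpy()  # shape (60000,)

# Replaying the transform by hand on the raw data (an assumption, see above)
manual = (raw_images.astype(np.float32) / 255.0 - 0.1307) / 0.3081
manual = manual[:, None, :, :]              # add the channel axis -> (60000, 1, 28, 28)

If that assumption holds, np.allclose(manual, train_images, atol=1e-6) should be True, and the background value (0 - 0.1307) / 0.3081 ≈ -0.4242 matches the output shown above.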
Using torch.utils.data.DataLoader is not necessary for this task.
from torchvision import datasets

train_set = datasets.MNIST('./data', train=True, download=True)
test_set = datasets.MNIST('./data', train=False, download=True)

train_set_array = train_set.data.numpy()
test_set_array = test_set.data.numpy()
Note that the targets are excluded in this case.
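If you also need the labels, a minimal extension of the snippet above works; targets is the standard attribute on torchvision's MNIST dataset (the variable names here are illustrative):

train_targets_array = train_set.targets.numpy()  # shape (60000,)
test_targets_array = test_set.targets.numpy()    # shape (10000,)

As noted in the previous answer, .data bypasses any transform, so the images come back as raw uint8 values in [0, 255].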