torch.utils.data.DataLoader - 为什么它增加了一个维度
torch.utils.data.DataLoader - why it adds a dimension
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) # A
trainloader = torch.utils.data.DataLoader(trainset.train_data, batch_size=64, shuffle=True) # B
dataiter = iter(trainloader)
images, labels = dataiter.next() # A
images = dataiter.next() # B
images.shape
为什么上面的代码,方法#A 给出 torch.Size([64, 1, 28, 28]) ,而#B 给出 torch.Size([64, 28, 28])? #A中值为1的第二个维度从哪里来?
提前致谢。
第二个维度描述了灰度为 1 的颜色通道。RGB 图像有 3 个通道(红色、绿色和蓝色),看起来像 64, 3, W, H
。
因此,在使用 CNN 时,您的数据通常必须具有 batchsize, channels, width, height
的形状,因此 64, 1, 28, 28
是正确的。
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) # A
trainloader = torch.utils.data.DataLoader(trainset.train_data, batch_size=64, shuffle=True) # B
dataiter = iter(trainloader)
images, labels = dataiter.next() # A
images = dataiter.next() # B
images.shape
为什么上面的代码,方法#A 给出 torch.Size([64, 1, 28, 28]) ,而#B 给出 torch.Size([64, 28, 28])? #A中值为1的第二个维度从哪里来?
提前致谢。
第二个维度描述了灰度为 1 的颜色通道。RGB 图像有 3 个通道(红色、绿色和蓝色),看起来像 64, 3, W, H
。
因此,在使用 CNN 时,您的数据通常必须具有 batchsize, channels, width, height
的形状,因此 64, 1, 28, 28
是正确的。