torch.utils.data.DataLoader - 为什么它增加了一个维度

torch.utils.data.DataLoader - why it adds a dimension

from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)            # A
trainloader = torch.utils.data.DataLoader(trainset.train_data, batch_size=64, shuffle=True) # B

dataiter = iter(trainloader)     
images, labels = dataiter.next() # A
images         = dataiter.next() # B
images.shape

为什么上面的代码,方法#A 给出 torch.Size([64, 1, 28, 28]) ,而#B 给出 torch.Size([64, 28, 28])? #A中值为1的第二个维度从哪里来?

提前致谢。

第二个维度描述了灰度为 1 的颜色通道。RGB 图像有 3 个通道(红色、绿色和蓝色),看起来像 64, 3, W, H。 因此,在使用 CNN 时,您的数据通常必须具有 batchsize, channels, width, height 的形状,因此 64, 1, 28, 28 是正确的。