MNIST、torchvision 中的输出和广播形状不匹配
Output and Broadcast shape mismatch in MNIST, torchvision
在 Torchvision 中使用 MNIST 数据集时出现以下错误
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
这是我的代码:
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
images, labels = next(iter(trainloader))
错误是由于数据集上的颜色与灰度,数据集是灰度。
我通过将转换更改为
来修复它
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
在 Torchvision 中使用 MNIST 数据集时出现以下错误
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
这是我的代码:
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
images, labels = next(iter(trainloader))
错误是由于数据集上的颜色与灰度,数据集是灰度。
我通过将转换更改为
来修复它transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])