RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 32, 32] instead

I'm new to PyTorch and to neural networks in general. I'm trying to run the ResNet-50 model from torchvision on the CIFAR-10 dataset.

import torchvision
import torch
import torch.nn as nn
from torch import optim
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt

transformations=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset=torchvision.datasets.CIFAR10(root='./CIFAR10',download=True,transform=transformations,train=True)

testset=torchvision.datasets.CIFAR10(root='./CIFAR10',download=True,transform=transformations,train=False)

trainloader=DataLoader(dataset=trainset,batch_size=4)
testloader=DataLoader(dataset=testset,batch_size=4)

inputs,labels=next(iter(trainset))
inputs.size()
resnet=torchvision.models.resnet50(pretrained=True)

if torch.cuda.is_available():
  resnet=resnet.cuda()
  inputs,labels=inputs.cuda(),torch.Tensor(labels).cuda()

outputs=resnet(inputs)

Output

--------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-904acb410fe4> in <module>()
      6   inputs,labels=inputs.cuda(),torch.Tensor(labels).cuda()
      7 
----> 8 outputs=resnet(inputs)

5 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    344                             _pair(0), self.dilation, self.groups)
    345         return F.conv2d(input, weight, self.bias, self.stride,
--> 346                         self.padding, self.dilation, self.groups)
    347 
    348     def forward(self, input):

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 32, 32] instead

Is something wrong with the dataset? If not, how do I provide a 4-dimensional input? Or does torchvision's ResNet-50 implementation simply not work with CIFAR-10?

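Right now you are iterating over the dataset, which is why you get a single (3-dimensional) image. You need to iterate over the data loader instead to get a 4-dimensional batch of images. So you only need to change this line: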

inputs,labels=next(iter(trainset))

to:

inputs,labels=next(iter(trainloader))
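To see the difference concretely, here is a minimal sketch. It uses a small random TensorDataset as a hypothetical stand-in for CIFAR-10, so nothing has to be downloaded. It also uses a single nn.Conv2d configured like ResNet-50's first layer (weight [64, 3, 7, 7]) in place of the full model. The sketch shows that a Dataset yields one (C, H, W) image, a DataLoader yields an (N, C, H, W) batch, and unsqueeze(0) turns a single image into a batch of one.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for CIFAR-10: 8 random 3x32x32 "images" with labels 0-9,
# so the sketch runs without downloading the real dataset.
images = torch.randn(8, 3, 32, 32)
targets = torch.randint(0, 10, (8,))
dataset = TensorDataset(images, targets)
loader = DataLoader(dataset, batch_size=4)

# Iterating the Dataset yields one sample: a 3-D tensor of shape (C, H, W).
single_image, single_label = next(iter(dataset))
print(single_image.shape)  # torch.Size([3, 32, 32])

# Iterating the DataLoader yields a batch: a 4-D tensor of shape (N, C, H, W).
batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)  # torch.Size([4, 3, 32, 32])

# To feed one image on its own, add the batch dimension explicitly.
one_image_batch = single_image.unsqueeze(0)
print(one_image_batch.shape)  # torch.Size([1, 3, 32, 32])

# Assumption: a lone Conv2d set up like ResNet-50's first layer
# (weight shape [64, 3, 7, 7]) stands in for the full model here.
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
print(conv1.weight.shape)  # torch.Size([64, 3, 7, 7])

# The 4-D batch goes through; 32x32 inputs become 16x16 feature maps.
with torch.no_grad():
    out = conv1(batch_images)
print(out.shape)  # torch.Size([4, 64, 16, 16])
```

The same unsqueeze(0) trick applies if you really do want to take a single image straight from trainset.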