RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 32, 32] instead
I'm new to PyTorch and to neural networks in general. I'm trying to run the ResNet-50 model from torchvision on the CIFAR-10 dataset.
```python
import torchvision
import torch
import torch.nn as nn
from torch import optim
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt

transformations=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset=torchvision.datasets.CIFAR10(root='./CIFAR10',download=True,transform=transformations,train=True)
testset=torchvision.datasets.CIFAR10(root='./CIFAR10',download=True,transform=transformations,train=False)
trainloader=DataLoader(dataset=trainset,batch_size=4)
testloader=DataLoader(dataset=testset,batch_size=4)

inputs,labels=next(iter(trainset))
inputs.size()

resnet=torchvision.models.resnet50(pretrained=True)
if torch.cuda.is_available():
    resnet=resnet.cuda()
    inputs,labels=inputs.cuda(),torch.Tensor(labels).cuda()

outputs=resnet(inputs)
```
Output:

```
--------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-6-904acb410fe4> in <module>()
6 inputs,labels=inputs.cuda(),torch.Tensor(labels).cuda()
7
----> 8 outputs=resnet(inputs)
5 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
344 _pair(0), self.dilation, self.groups)
345 return F.conv2d(input, weight, self.bias, self.stride,
--> 346 self.padding, self.dilation, self.groups)
347
348 def forward(self, input):
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 32, 32] instead
```
Is there something wrong with the dataset? If not, how do I provide a 4-dimensional input? Or is the torchvision implementation of ResNet-50 simply not suitable for CIFAR-10?
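Currently you are iterating over the dataset, which is why you get a single (3-dimensional) image. You actually need to iterate over the data loader to get a 4-dimensional batch of images. So you only need to change the line

```python
inputs,labels=next(iter(trainset))
```

to

```python
inputs,labels=next(iter(trainloader))
```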
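To make the difference concrete, here is a minimal sketch that runs without downloading CIFAR-10. It assumes a hypothetical `TensorDataset` of random 3×32×32 tensors as a stand-in for `trainset`, and a single `nn.Conv2d` with the same `[64, 3, 7, 7]` weight shape as ResNet-50's first layer as a stand-in for `resnet`. Taking one item from the dataset gives a 3-D `[C, H, W]` tensor, while taking one item from the `DataLoader` gives a stacked 4-D `[N, C, H, W]` batch; `unsqueeze(0)` is an alternative if you really want to feed a single image.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for CIFAR-10: 8 random 3x32x32 "images" with labels 0-9.
# (torchvision.datasets.CIFAR10 with ToTensor() also yields [3, 32, 32] image tensors.)
images = torch.randn(8, 3, 32, 32)
targets = torch.randint(0, 10, (8,))
dataset = TensorDataset(images, targets)
loader = DataLoader(dataset, batch_size=4)

# Iterating the dataset yields ONE sample: a 3-D image and a 0-D label.
img, label = next(iter(dataset))
print(img.shape)            # torch.Size([3, 32, 32])

# Iterating the DataLoader yields a BATCH: the default collate function stacks
# the samples along a new leading dimension, giving [N, C, H, W].
batch_imgs, batch_labels = next(iter(loader))
print(batch_imgs.shape)     # torch.Size([4, 3, 32, 32])
print(batch_labels.shape)   # torch.Size([4])

# Alternative for a single image: add the batch dimension explicitly.
single = img.unsqueeze(0)
print(single.shape)         # torch.Size([1, 3, 32, 32])

# Stand-in for ResNet-50's first layer (same weight shape [64, 3, 7, 7]).
conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
print(conv1.weight.shape)   # torch.Size([64, 3, 7, 7])
out = conv1(batch_imgs)     # 4-D batch in, 4-D feature maps out
print(out.shape)            # torch.Size([4, 64, 16, 16])
```

The same reasoning applies to the real `trainloader`: each `next(iter(trainloader))` returns `batch_size=4` images stacked into a `[4, 3, 32, 32]` tensor, which is the layout the model's convolution layers expect.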