RuntimeError: mat1 and mat2 shapes cannot be multiplied (5400x64 and 5400x64)

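I am working on an image classification network and ran into a problem getting the right input and output sizes in the forward() function. I have no idea how to solve it, because the two shapes look the same to me. The error comes from this line: x = F.relu(self.fc1(x)), but I can't figure out why.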

Can anyone help me solve this?

Here is my code:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=2)
        self.conv2 = nn.Conv2d(8, 12, kernel_size=2)
        self.conv3 = nn.Conv2d(12, 18, kernel_size=2)
        self.conv4 = nn.Conv2d(18, 24, kernel_size=2)
        self.fc1 = nn.Linear(5400, 64)
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        print(f'1. {x.size()}')
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        print(f'2. {x.size()}')
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        print(f'3. {x.size()}')
        x = self.conv3(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        print(f'4. {x.size()}')
        x = self.conv4(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        print(f'5. {x.size()}')
        x = x.view(-1, x.size(0)) 
        print(f'6. {x.size()}')
        x = F.relu(self.fc1(x))
        print(f'7. {x.size()}')
        x = self.fc2(x)
        print(f'8. {x.size()}')
        
        return torch.sigmoid(x)

Here is the printed output:

1. torch.Size([64, 3, 256, 256])
2. torch.Size([64, 8, 127, 127])
3. torch.Size([64, 12, 63, 63])
4. torch.Size([64, 18, 31, 31])
5. torch.Size([64, 24, 15, 15])
6. torch.Size([5400, 64])
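
For reference, the spatial sizes in prints 2 to 5 and the number 5400 can be reproduced with plain arithmetic. This is a minimal sketch, assuming the default stride of 1 and no padding for `nn.Conv2d`, and a pooling stride equal to the kernel size for `F.max_pool2d(x, 2)`; the helper names `conv_out` and `pool_out` are hypothetical and not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output length of a convolution along one spatial axis (floor division).
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    # Max pooling whose stride equals its kernel size, as in F.max_pool2d(x, 2).
    return (size - kernel) // kernel + 1


size = 256          # input height/width from print 1
sizes = []
for _ in range(4):  # four conv (kernel_size=2) + max-pool (2) stages
    size = pool_out(conv_out(size, kernel=2), kernel=2)
    sizes.append(size)

# Features per sample after the last stage: channels * height * width.
flat_features = 24 * size * size

print(sizes)          # [127, 63, 31, 15]
print(flat_features)  # 5400
```

So 5400 is the number of features per sample (24 x 15 x 15), which is why it has to end up on the last axis of the tensor passed to `fc1`.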

I think changing

x = x.view(-1, x.size(0))

to

x = x.view(-1, 5400)

will solve your problem. As you can see in print 6:

6. torch.Size([5400, 64])

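the batch size, 64, is on axis 1 rather than axis 0. The fully connected layer expects an input of size 5400, so this change should solve the problem: you may not know the batch size, but you do know that the input to the fully connected layer is 5400.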
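
The difference between the two reshapes can be checked in isolation. The snippet below is a small sketch, not the asker's full network: it uses a random tensor as a stand-in for the output of the last pooling stage (print 5) and a freshly created `nn.Linear(5400, 64)` in place of `self.fc1`.

```python
import torch
import torch.nn as nn

# Stand-in for the activations at print 5: (batch, channels, height, width).
x = torch.randn(64, 24, 15, 15)

# Reshape from the question: 5400 lands on axis 0, the batch size on axis 1.
wrong = x.view(-1, x.size(0))
print(wrong.shape)  # torch.Size([5400, 64])

# Keep the batch on axis 0 and flatten everything else into one axis.
flat = x.view(x.size(0), -1)
print(flat.shape)   # torch.Size([64, 5400])

# Stand-in for self.fc1; it expects 5400 features on the last axis.
fc1 = nn.Linear(5400, 64)
out = fc1(flat)
print(out.shape)    # torch.Size([64, 64])
```

For this input, `x.view(-1, 5400)` and `torch.flatten(x, 1)` give the same `[64, 5400]` result. Writing `x.view(x.size(0), -1)` avoids hard-coding 5400 in `forward()`, although `nn.Linear(5400, 64)` still fixes that number in `__init__`.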