在 Pytorch 中计算 3D CNN 的填充

Calculate padding for 3D CNN in Pytorch

我目前正在尝试将 3D CNN 应用于一组尺寸为 193 x 229 x 193 的图像,并希望通过每个卷积层保留相同的图像尺寸(类似于 tensorflow padding=SAME).我知道填充可以计算如下:

S=Stride
P=Padding
W=Width
K=Kernal size

P = ((S-1)*W-S+K)/2

第一层的填充为 1:

P = ((1-1)*193-1+3)/2
P= 1.0

虽然我也得到了每个后续层的 1.0 结果。有人有什么建议吗?抱歉,初学者!

可重现的例子:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 193, 229, 193)

padding = ((1-1)*96-1+3)/2
print(padding)

x = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3, padding=1)(x)
print("shape after conv1: " + str(x.shape))
x = nn.Conv3d(in_channels=8, out_channels=8, kernel_size=3,padding=1)(x)
x = nn.BatchNorm3d(8)(x) 
print("shape after conv2 + batch norm: " + str(x.shape))
x = nn.ReLU()(x)
print("shape after reLU:" + str(x.shape))
x = nn.MaxPool3d(kernel_size=2, stride=2)(x)
print("shape after max pool" + str(x.shape))
x = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=3,padding=1)(x)
print("shape after conv3: " + str(x.shape))
x = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=3,padding=1)(x)
print("shape after conv4: " + str(x.shape))

当前输出:

shape after conv1: torch.Size([1, 8, 193, 229, 193])
shape after conv2 + batch norm: torch.Size([1, 8, 193, 229, 193])
shape after reLU:torch.Size([1, 8, 193, 229, 193])
shape after max pooltorch.Size([1, 8, 96, 114, 96])
shape after conv3: torch.Size([1, 16, 96, 114, 96])
shape after conv4: torch.Size([1, 16, 96, 114, 96])

期望的输出:

shape after conv1: torch.Size([1, 8, 193, 229, 193])
shape after conv2 + batch norm: torch.Size([1, 8, 193, 229, 193])
...
shape after conv3: torch.Size([1, 16, 193, 229, 193])
shape after conv4: torch.Size([1, 16, 193, 229, 193])

TLDR;您的公式也适用于 nn.MaxPool3d

您正在使用内核大小为 2(隐式 (2,2,2))的最大池层,步幅为 2(隐式 (2,2,2))。这意味着对于每个 2x2x2 块,您只会获得一个值。换句话说 - 顾名思义:只有来自每个 2x2x2 块的最大值被汇集到输出数组。

这就是为什么你要从 (1, 8, 193, 229, 193)(1, 8, 96, 114, 96)(注意除以 2)。

当然,如果您在 nn.MaxPool3d 上设置 kernel_size=3stride=1,您将保留块的形状。


#x为输入形状,#w为内核形状。如果我们希望输出具有相同的大小,则 #x = floor((#x + 2p - #w)/s + 1) 需要为真。那是 2p = s(#x - 1) - #x + #w = #x(s - 1) + #w - s(你的公式)

因为s = 2#w = 2,所以2p = #x这是不可能的。