如何使用 pytorch 在网络模型中对 3D 体积进行中心裁剪

How can I do the centercrop of 3D volumes inside the network model with pytorch

keras中,有Cropping3D层用于神经网络内部3D体积的centercropping张量。然而,我没能在 pytorch 中找到任何类似的东西,尽管它们有 torchvision.transforms.CenterCrop(size) 用于 2D 图像。

如何在网络内进行裁剪?否则我需要在预处理中做这件事,这是我出于特定原因最不想做的事情。

我是否需要编写自定义层,例如沿每个轴对输入张量进行切片?希望能得到一些启发

在 PyTorch 中,您不一定需要为所有内容编写层,通常您可以在前向传递期间直接执行您想要的操作。在需要计算梯度的火炬张量上操作时,您需要牢记的基本规则是

  1. 不要将 torch 张量转换为其他类型进行计算(例如使用 torch.sum 而不是转换为 numpy 并使用 numpy.sum)。
  2. 不要执行就地操作(例如更改张量的一个元素或使用就地运算符,因此请使用 x = x + ... 而不是 x += ...)。

就是说,你可以只使用切片,也许它看起来像这样

def forward(self, x):
    ...
    x = self.conv3(x)
    x = x[:, :, 5:20, 5:20]    # crop out part of the feature map
    x = self.relu3(x)
    ...