如何使用 pytorch 在网络模型中对 3D 体积进行中心裁剪
How can I do the centercrop of 3D volumes inside the network model with pytorch
在keras
中,有Cropping3D
层用于神经网络内部3D体积的centercropping张量。然而,我没能在 pytorch 中找到任何类似的东西,尽管它们有 torchvision.transforms.CenterCrop(size)
用于 2D 图像。
如何在网络内进行裁剪?否则我需要在预处理中做这件事,这是我出于特定原因最不想做的事情。
我是否需要编写自定义层,例如沿每个轴对输入张量进行切片?希望能得到一些启发
在 PyTorch 中,您不一定需要为所有内容编写层,通常您可以在前向传递期间直接执行您想要的操作。在需要计算梯度的火炬张量上操作时,您需要牢记的基本规则是
- 不要将 torch 张量转换为其他类型进行计算(例如使用
torch.sum
而不是转换为 numpy 并使用 numpy.sum
)。
- 不要执行就地操作(例如更改张量的一个元素或使用就地运算符,因此请使用
x = x + ...
而不是 x += ...
)。
就是说,你可以只使用切片,也许它看起来像这样
def forward(self, x):
...
x = self.conv3(x)
x = x[:, :, 5:20, 5:20] # crop out part of the feature map
x = self.relu3(x)
...
在keras
中,有Cropping3D
层用于神经网络内部3D体积的centercropping张量。然而,我没能在 pytorch 中找到任何类似的东西,尽管它们有 torchvision.transforms.CenterCrop(size)
用于 2D 图像。
如何在网络内进行裁剪?否则我需要在预处理中做这件事,这是我出于特定原因最不想做的事情。
我是否需要编写自定义层,例如沿每个轴对输入张量进行切片?希望能得到一些启发
在 PyTorch 中,您不一定需要为所有内容编写层,通常您可以在前向传递期间直接执行您想要的操作。在需要计算梯度的火炬张量上操作时,您需要牢记的基本规则是
- 不要将 torch 张量转换为其他类型进行计算(例如使用
torch.sum
而不是转换为 numpy 并使用numpy.sum
)。 - 不要执行就地操作(例如更改张量的一个元素或使用就地运算符,因此请使用
x = x + ...
而不是x += ...
)。
就是说,你可以只使用切片,也许它看起来像这样
def forward(self, x):
...
x = self.conv3(x)
x = x[:, :, 5:20, 5:20] # crop out part of the feature map
x = self.relu3(x)
...