使用 3 通道 (RGB) PyTorch 模型对 4 通道 (RGBY) 图像进行分类

Using 3-channel (RGB) PyTorch model for classification 4-channel (RGBY) images

我用 4 通道图像 (RGBY) 标记了数据集。我想使用预训练分类模型(使用 pytorch 和 ResNet50 作为模型)。不过,所有 pytorch 个模型都用于 3 个通道。 所以,问题是:我如何使用3通道预训练模型来处理4通道数据?我正在以接下来的方式加载模型:

import torchvision.models as models
resnet50 = models.resnet50(pretrained=True)

您可以修改 CNN 的第一层,使其需要 4 个输入通道而不是 3 个。在您的例子中,第一层是 resnet50.conv1。所以:

import torchvision.models as models
resnet50 = models.resnet50(pretrained=True)

# modify first layer so it expects 4 input channels; all other parameters unchanged
resnet50.conv1 = torch.nn.Conv2d(4,64,kernel_size = (7,7),stride = (2,2), padding = (3,3), bias = False) 

# test
inp = torch.rand([1,4,512,512])
resnet50.eval()
resnet50.training = False
out = resnet50(inp) # should evaluate without error

以下实现细节使这一变化变得简单:对于 2D 卷积(对于其他维度的卷积也是如此),pytorch 为每个所需的输出平面(特征图)与每个输入平面卷积一个内核。这导致 n_input_planes x n_output_planes 总特征图(在本例中分别为 4 和 64)。然后,Pytorch 对每个输出平面的所有输入平面求和,无论输入平面的数量如何,总计 n_output_planes 个平面。

好消息是,这意味着您可以在不修改第一层之后的网络的情况下添加额外的输入平面(映射)。这(可能在某些情况下)不利的部分是所有输入特征图都被相同地处理,并且来自每个特征图的信息在第一个卷积结束时被完全合并。在某些情况下,可能希望在开始时以不同方式处理输入特征映射,在这种情况下,您需要定义两个单独的 CNN 分支,以便特征不会在每一层添加在一起。