将模型从 3 通道 (RGB) 重新训练为 4 通道 (RGBA),我可以使用 3 通道权重吗?

Retraining a Model from 3 Channels (RGB) to 4 Channels (RGBA), can I use the 3 channel weights?

我需要将模型从 RGB 扩展到 RGBA。我可以处理模型上的代码重写,但与其从头开始重新训练整个模型,我更愿意从它的 3 通道权重 + 零开始。

有没有简单的方法可以将 torch 保存的 3 个通道权重更改为 4 个?

是的,你可以做一点“模型手术”。假设模型的输入仅由卷积层直接处理,那么您只需将该卷积层替换为另一个 in_channels 设置为 4 的层即可。然后您可以将权重设置为零并从原始 conv 层复制旧权重(和偏差,如果适用)。

例如,假设我们有一个看起来像这样的简单模型

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
        self.linear = nn.Linear(125, 1)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.linear(x.flatten(start_dim=1))

model = SimpleModel()

假设此时训练模型,我们可以进行如下手术

y_rgb = torch.randn(1, 3, 5, 5)

# get performance on initial z_rgb
z_rgb = model(y_rgb)

# perform model surgery
with torch.no_grad():
    new_conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=True)
    new_conv1.weight.zero_()
    new_conv1.weight[:,:3,...]=model.conv1.weight
    new_conv1.bias.copy_(model.conv1.bias)
    model.conv1 = new_conv1

# add a random alpha channel to y_rgba
y_alpha = torch.randn(1,1,5,5)
y_rgba = torch.cat([y_rgb, y_alpha], dim=1)

# get results on rgba model
z_rgba = model(y_rgba)

# compare z_rgb and z_rgba, print mean-square difference
z_err = ((z_rgba-z_rgb)**2).mean().item()
print('Err:', z_err)

# save results to a new file
torch.save(model.state_dict(), 'checkpoint_rgba.pt')

这应该给你一个零或非常接近零的错误。

当然,如果您在第一个转换层中没有 bias 术语,则无需复制它。

假设您已经保存了新的状态字典,那么您可能想要更新模型 class 定义,以便您的输入卷积层采用 4 个通道输入而不是 3 个。那么下次您可以直接无需额外步骤即可加载新的状态字典。


现在不需要直接对模型进行手术了。虽然我更喜欢它,因为我发现它更容易验证正确性。

假设您保存了 RGB 模型的状态字典,您也可以直接修改状态字典。

# assuming you saved RGB model using torch.save(model.state_dict(), 'checkpoint_rgb.pt')
state_dict = torch.load('checkpoint_rgb.pt')
old_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = torch.zeros(
    old_weight.shape[0],
    old_weight.shape[1]+1,
    old_weight.shape[2],
    old_weight.shape[3]
).type_as(old_weight)
state_dict['conv1.weight'][:,:3,...] = old_weight
torch.save(state_dict, 'checkpoint_rgba.pt')