将模型从 3 通道 (RGB) 重新训练为 4 通道 (RGBA),我可以使用 3 通道权重吗?
Retraining a Model from 3 Channels (RGB) to 4 Channels (RGBA), can I use the 3 channel weights?
我需要将模型从 RGB 扩展到 RGBA。我可以处理模型上的代码重写,但与其从头开始重新训练整个模型,我更愿意从它的 3 通道权重 + 零开始。
有没有简单的方法可以将 torch 保存的 3 个通道权重更改为 4 个?
是的,你可以做一点“模型手术”。假设模型的输入仅由卷积层直接处理,那么您只需将该卷积层替换为另一个 in_channels
设置为 4
的层即可。然后您可以将权重设置为零并从原始 conv 层复制旧权重(和偏差,如果适用)。
例如,假设我们有一个看起来像这样的简单模型
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=True)
self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
self.linear = nn.Linear(125, 1)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
return self.linear(x.flatten(start_dim=1))
model = SimpleModel()
假设此时训练模型,我们可以进行如下手术
y_rgb = torch.randn(1, 3, 5, 5)
# get performance on initial z_rgb
z_rgb = model(y_rgb)
# perform model surgery
with torch.no_grad():
new_conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=True)
new_conv1.weight.zero_()
new_conv1.weight[:,:3,...]=model.conv1.weight
new_conv1.bias.copy_(model.conv1.bias)
model.conv1 = new_conv1
# add a random alpha channel to y_rgba
y_alpha = torch.randn(1,1,5,5)
y_rgba = torch.cat([y_rgb, y_alpha], dim=1)
# get results on rgba model
z_rgba = model(y_rgba)
# compare z_rgb and z_rgba, print mean-square difference
z_err = ((z_rgba-z_rgb)**2).mean().item()
print('Err:', z_err)
# save results to a new file
torch.save(model.state_dict(), 'checkpoint_rgba.pt')
这应该给你一个零或非常接近零的错误。
当然,如果您在第一个转换层中没有 bias
术语,则无需复制它。
假设您已经保存了新的状态字典,那么您可能想要更新模型 class 定义,以便您的输入卷积层采用 4 个通道输入而不是 3 个。那么下次您可以直接无需额外步骤即可加载新的状态字典。
现在不需要直接对模型进行手术了。虽然我更喜欢它,因为我发现它更容易验证正确性。
假设您保存了 RGB 模型的状态字典,您也可以直接修改状态字典。
# assuming you saved RGB model using torch.save(model.state_dict(), 'checkpoint_rgb.pt')
state_dict = torch.load('checkpoint_rgb.pt')
old_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = torch.zeros(
old_weight.shape[0],
old_weight.shape[1]+1,
old_weight.shape[2],
old_weight.shape[3]
).type_as(old_weight)
state_dict['conv1.weight'][:,:3,...] = old_weight
torch.save(state_dict, 'checkpoint_rgba.pt')
我需要将模型从 RGB 扩展到 RGBA。我可以处理模型上的代码重写,但与其从头开始重新训练整个模型,我更愿意从它的 3 通道权重 + 零开始。
有没有简单的方法可以将 torch 保存的 3 个通道权重更改为 4 个?
是的,你可以做一点“模型手术”。假设模型的输入仅由卷积层直接处理,那么您只需将该卷积层替换为另一个 in_channels
设置为 4
的层即可。然后您可以将权重设置为零并从原始 conv 层复制旧权重(和偏差,如果适用)。
例如,假设我们有一个看起来像这样的简单模型
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=True)
self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
self.linear = nn.Linear(125, 1)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
return self.linear(x.flatten(start_dim=1))
model = SimpleModel()
假设此时训练模型,我们可以进行如下手术
y_rgb = torch.randn(1, 3, 5, 5)
# get performance on initial z_rgb
z_rgb = model(y_rgb)
# perform model surgery
with torch.no_grad():
new_conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=True)
new_conv1.weight.zero_()
new_conv1.weight[:,:3,...]=model.conv1.weight
new_conv1.bias.copy_(model.conv1.bias)
model.conv1 = new_conv1
# add a random alpha channel to y_rgba
y_alpha = torch.randn(1,1,5,5)
y_rgba = torch.cat([y_rgb, y_alpha], dim=1)
# get results on rgba model
z_rgba = model(y_rgba)
# compare z_rgb and z_rgba, print mean-square difference
z_err = ((z_rgba-z_rgb)**2).mean().item()
print('Err:', z_err)
# save results to a new file
torch.save(model.state_dict(), 'checkpoint_rgba.pt')
这应该给你一个零或非常接近零的错误。
当然,如果您在第一个转换层中没有 bias
术语,则无需复制它。
假设您已经保存了新的状态字典,那么您可能想要更新模型 class 定义,以便您的输入卷积层采用 4 个通道输入而不是 3 个。那么下次您可以直接无需额外步骤即可加载新的状态字典。
现在不需要直接对模型进行手术了。虽然我更喜欢它,因为我发现它更容易验证正确性。
假设您保存了 RGB 模型的状态字典,您也可以直接修改状态字典。
# assuming you saved RGB model using torch.save(model.state_dict(), 'checkpoint_rgb.pt')
state_dict = torch.load('checkpoint_rgb.pt')
old_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = torch.zeros(
old_weight.shape[0],
old_weight.shape[1]+1,
old_weight.shape[2],
old_weight.shape[3]
).type_as(old_weight)
state_dict['conv1.weight'][:,:3,...] = old_weight
torch.save(state_dict, 'checkpoint_rgba.pt')