如何连接 2 个 PyTorch 模型并使第一个在 PyTorch 中不可训练

How to concatenate 2 pytorch models and make the first one non-trainable in PyTorch

我有两个网络,我需要为我的完整模型连接它们。但是我的第一个模型是预训练的,我需要在训练完整模型时让它不可训练。我怎样才能在 PyTorch 中实现这一点。

我可以使用 this answer

连接两个模型
class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x
    

class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(20, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x


class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        
    def forward(self, x):
        x1 = self.modelA(x)
        x2 = self.modelB(x1)
        return x2

# Create models and load state_dicts    
modelA = MyModelA()
modelB = MyModelB()
# Load state dicts
modelA.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)

基本上在这里,我想加载预训练的 modelA 并在训练 Ensemble 模型时使其不可训练。

您可以通过将 requires_grad 设置为 false 来冻结您不想训练的模型的所有参数。 像这样:

for param in model.parameters():
    param.requires_grad = False

这应该适合你。

另一种方法是在你的训练循环中处理这个问题:

modelA = MyModelA()
modelB = MyModelB()

criterionB = nn.MSELoss()
optimizerB = torch.optim.Adam(modelB.parameters(), lr=0.001)

for epoch in range(epochs):
    for samples, targets in dataloader:
        optimizerB.zero_grad()

        x = modelA.train()(samples)
        predictions = modelB.train()(samples)
    
        loss = criterionB(predictions, targets)
        loss.backward()
        optimizerB.step()

因此,您将模型 A 的输出传递给模型 B,但您只优化了模型 B。

一个简单的方法是 detach 您不想更新的模型的输出张量,它不会将梯度反向传播到连接的模型。在您的情况下,您可以在 MyEnsemble 模型的前向函数中与 x1 连接之前简单地 detach x2 张量,以保持 modelB 的权重不变。

因此,新的前向函数应该如下所示:

def forward(self, x1, x2):
        x1 = self.modelA(x1)
        x2 = self.modelB(x2)
        x = torch.cat((x1, x2.detach()), dim=1)  # Detaching x2, so modelB wont be updated
        x = self.classifier(F.relu(x))
        return x