训练pytorch模型后如何丢弃分支

How to discard a branch after training a pytorch model

我正在尝试在 pytorch 中实现一个 FCN,其整体结构如下:

到目前为止的代码如下所示:

class SNet(nn.Module):
    def __init__(self):
        super(SNet, self).__init__()
        
        self.enc_a = encoder(...)
        self.dec_a = decoder(...)
        
        self.enc_b = encoder(...)
        self.dec_b = decoder(...)
    
    def forward(self, x1, x2):
        x1 = self.enc_a(x1)
        x2 = self.enc_b(x2)
        x2 = self.dec_b(x2)
        x1 = self.dec_a(torch.cat((x1, x2), dim=-1)
        return x1, x2

keras 中,使用函数 API 相对容易做到这一点。但是,我在 pytorch.

中找不到任何具体示例/教程来执行此操作
  1. 如何在训练后丢弃 dec_a(自动编码器分支的解码器部分)?
  2. 在联合训练期间,loss 将是来自两个分支的 loss 的总和(可选加权)?

您还可以为训练和推理模型定义单独的模式:

class SNet(nn.Module):
  def __init__(self):
    super(SNet, self).__init__()
    
    self.enc_a = encoder(...)
    self.dec_a = decoder(...)
    
    self.enc_b = encoder(...)
    self.dec_b = decoder(...)
    
    self.training = True

  def forward(self, x1, x2):
    if self.training:
        x1 = self.enc_a(x1)
        x2 = self.enc_b(x2)
        x2 = self.dec_b(x2)
        x1 = self.dec_a(torch.cat((x1, x2), dim=-1)
        return x1, x2
    else:
        x1 = self.enc_a(x1)
        x2 = self.enc_b(x2)
        x2 = self.dec_b(x2)
        return x2

这些块是示例,可能无法完全满足您的要求,因为我认为您在块图中定义训练和推理操作的方式与您的代码之间存在一些歧义,但无论如何您都会得到如何仅在训练模式下使用某些模块的想法。然后你可以相应地设置这个变量。