How to Define Model in Native Pytorch Format and Import Into LightningModule Without Copy and Pasting?

Suppose I have a native PyTorch model like this:

from torch import nn

class NormalAutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

How can I get the __init__ and forward functionality (basically the whole network) into a PyTorch Lightning module without copying and pasting?

Simple: use Python's inheritance mechanism.

If the following is your native PyTorch module:

class NormalAutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ...
        self.decoder = ...

    def forward(self, x):
        embedding = ...
        return embedding

then create a new LightningAutoEncoder that inherits from NormalAutoEncoder:

class LightningAutoEncoder(LightningModule, NormalAutoEncoder):
    def __init__(self, ...):
        LightningModule.__init__(self) # only LightningModule's init
        NormalAutoEncoder.__init__(self, ...) # this basically executes __init__() of the NormalAutoEncoder

    def forward(self, x):
        # offloads its execution to NormalAutoEncoder's forward() method
        return NormalAutoEncoder.forward(self, x)
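
For completeness, here is a minimal sketch of what the fleshed-out subclass might look like in practice. The training_step, configure_optimizers, and the commented-out Trainer usage below are illustrative assumptions (including the MNIST-style flattening), not part of the original answer; the key point is that self.encoder and self.decoder come from NormalAutoEncoder unchanged.

# Illustrative sketch only: the training loop details are assumptions layered on
# top of the inheritance trick shown above.
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer

class LightningAutoEncoder(LightningModule, NormalAutoEncoder):
    def __init__(self):
        LightningModule.__init__(self)    # only LightningModule's init
        NormalAutoEncoder.__init__(self)  # builds self.encoder / self.decoder

    def forward(self, x):
        # reuse the native module's forward unchanged
        return NormalAutoEncoder.forward(self, x)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)         # flatten 28x28 images to 784 (assumed input shape)
        z = self.encoder(x)               # layers inherited from NormalAutoEncoder
        x_hat = self.decoder(z)
        return F.mse_loss(x_hat, x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# trainer = Trainer(max_epochs=1)
# trainer.fit(LightningAutoEncoder(), train_dataloader)  # train_dataloader: any DataLoader of (image, label) batches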

That's it. No copying and pasting.