如何在不复制和粘贴的情况下以原生 Pytorch 格式定义模型并导入 LightningModule?
How to Define Model in Native Pytorch Format and Import Into LightningModule Without Copy and Pasting?
假设我有一个像这样的原生 pytorch 模型
class NormalAutoEncoder(nn.Module)):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
embedding = self.encoder(x)
return embedding
如何在不复制和粘贴的情况下将__init__
和forward
功能(基本上是整个网络)放入pytorch照明模块?
简单。利用Python的继承机制。
如果下面是原生的PyTorch模块
class NormalAutoEncoder(nn.Module):
def __init__(self):
super().__init__()
self.encoder = ...
self.decoder = ...
def forward(self, x):
embedding = ...
return embedding
然后让 也 继承自 NormalAutoEncoder
的新 LightningAutoEncoder
class LightningAutoEncoder(LightningModule, NormalAutoEncoder):
def __init__(self, ...):
LightningModule.__init__(self) # only LightningModule's init
NormalAutoEncoder.__init__(self, ...) # this basically executes __init__() of the NormalAutoEncoder
def forward(self, x):
# offloads its execution to NormalAutoEncoder's forward() method
return NormalAutoEncoder.forward(self, x)
就是这样。禁止复制粘贴。
假设我有一个像这样的原生 pytorch 模型
class NormalAutoEncoder(nn.Module)):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
embedding = self.encoder(x)
return embedding
如何在不复制和粘贴的情况下将__init__
和forward
功能(基本上是整个网络)放入pytorch照明模块?
简单。利用Python的继承机制。
如果下面是原生的PyTorch模块
class NormalAutoEncoder(nn.Module):
def __init__(self):
super().__init__()
self.encoder = ...
self.decoder = ...
def forward(self, x):
embedding = ...
return embedding
然后让 也 继承自 NormalAutoEncoder
LightningAutoEncoder
class LightningAutoEncoder(LightningModule, NormalAutoEncoder):
def __init__(self, ...):
LightningModule.__init__(self) # only LightningModule's init
NormalAutoEncoder.__init__(self, ...) # this basically executes __init__() of the NormalAutoEncoder
def forward(self, x):
# offloads its execution to NormalAutoEncoder's forward() method
return NormalAutoEncoder.forward(self, x)
就是这样。禁止复制粘贴。