使用 Pytorch 从自动编码器中提取隐藏表示

Extracting hidden representations from an autoencoder using Pytorch

用 PyTorch 训练自动编码器后,如何提取输入特征在某个隐藏层的低维嵌入?

您可以只定义您的模型,以便它可以选择 returns 在前向传递期间计算的中间 pytorch 变量。简单例子:

class Autoencoder(nn.Module):
    def __init__(self, input_size, hidden_size):
    super().__init__()
    self.encoder = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, 3)) #reduce the size

    self.decoder = nn.Sequential(
    nn.Linear(3, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, input_size),
    nn.ReLU()) #reduce the size

def forward(self, x, return_encoding = False):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)

    if return_encoding:
        return decoded,encoded
    return decoded