pytorch seq2seq编码器前向方法
pytorch seq2seq encoder forward method
我正在关注 Pytorch seq2seq tutorial 下面是他们如何定义编码器功能。
class EncoderRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super(EncoderRNN, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
def forward(self, input, hidden):
embedded = self.embedding(input).view(1, 1, -1)
output = embedded
output, hidden = self.gru(output, hidden)
return output, hidden
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
但是,似乎 forward
方法在训练期间从未真正被调用过。
以下是编码器前向方法在教程中的使用方式:
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
不应该是 encoder.forward
而不是 encoder
吗?
Pytorch 中是否有一些我不知道的自动 'forward' 机制?
在 PyTorch 中,您通过扩展 torch.nn.Module
编写自己的 class 并定义 forward 方法来表达您想要的计算步骤,作为 "paperwork"(例如调用钩子) model.__call__(...)
方法(model(x) 将通过 python 特殊名称规范调用的方法)。
如果你很好奇,你可以看看 model(x)
除了调用 model.forward(x)
之外在幕后做了什么:https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L462
此外,您可以在此处了解显式调用 .foward(x)
方法与仅使用 model(x)
之间的区别:https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L72
我正在关注 Pytorch seq2seq tutorial 下面是他们如何定义编码器功能。
class EncoderRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super(EncoderRNN, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
def forward(self, input, hidden):
embedded = self.embedding(input).view(1, 1, -1)
output = embedded
output, hidden = self.gru(output, hidden)
return output, hidden
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
但是,似乎 forward
方法在训练期间从未真正被调用过。
以下是编码器前向方法在教程中的使用方式:
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
不应该是 encoder.forward
而不是 encoder
吗?
Pytorch 中是否有一些我不知道的自动 'forward' 机制?
在 PyTorch 中,您通过扩展 torch.nn.Module
编写自己的 class 并定义 forward 方法来表达您想要的计算步骤,作为 "paperwork"(例如调用钩子) model.__call__(...)
方法(model(x) 将通过 python 特殊名称规范调用的方法)。
如果你很好奇,你可以看看 model(x)
除了调用 model.forward(x)
之外在幕后做了什么:https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L462
此外,您可以在此处了解显式调用 .foward(x)
方法与仅使用 model(x)
之间的区别:https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L72