什么相当于 pytorch lstm num_layers?

What is equivalent to pytorch lstm num_layers?

我是 PyTorch 的初学者。从 lstm description,我了解到我可以通过以下方式创建具有 3 层的堆叠 lstm:

layer = torch.nn.LSTM(128, 512, num_layers=3)

然后在forward函数中,我可以做:

def forward(x, state):
    x, state = layer(x, state)
    return x, (state[0].detach(), state[1].detach())

而且我可以 state 从一批到另一批。
但是如果我创建 3 个 lstm 层,如果我想自己实现相同的堆叠层,那相当于什么?

layer1 = torch.nn.LSTM(128, 512, num_layers=1)
layer2 = torch.nn.LSTM(128, 512, num_layers=1)
layer3 = torch.nn.LSTM(128, 512, num_layers=1)

在这种情况下,forward 函数应该输入什么并得到返回的 state
我还尝试查看 pytorch lstm 的 source code,但在 forward 函数中它调用了一个 _VF 模块,我找不到它的定义位置。

如果将state定义为3层状态的列表,则

def forward(x, state):
    x, s0 = layer1(x, state[0])
    x, s1 = layer2(x, state[1])
    x, s2 = layer3(x, state[2])
    return x, [s0.detach(), s1.detach(), s2.detach()]