pytorch中LSTM的参数和函数调用

arguments and function call of LSTM in pytorch

谁能给我解释一下下面的代码:

import torch
import torch.nn as nn

input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)

rnn = nn.LSTM(10,20,2)

output, (hn, cn) = rnn(input, (h0, c0))
print(input)

在调用 rnn rnn(input, (h0, c0)) 时,我们在括号中给出了参数 h0 和 c0。这是什么意思?如果 (h0, c0) 表示单个值,那么该值是什么,这里传递的第三个参数是什么? 但是,在行 rnn = nn.LSTM(10,20,2) 中,我们在 LSTM 函数中传递参数而不带括号。 谁能解释一下这个函数调用是如何工作的?

赋值 rnn = nn.LSTM(10, 20, 2) 使用 nn.LSTM class 实例化了一个新的 nn.Module。它的前三个参数是input_size(此处为10)、hidden_size(此处为20)和num_layers(此处为2)。

另一方面 rnn(input, (h0, c0)) 对应于实际调用 class 实例, i.e. 运行 __call__ 大致相当于 forward 该模块的功能。 nn.LSTM__call__ 方法接受两个参数:input(形 (sequnce_length, batch_size, input_size) 和两个张量 (h_0, c_0) 的元组(形如 (num_layers, batch_size, hidden_size) nn.LSTM)

的基本用例中

请在使用内置函数时参考 PyTorch 文档,您会找到参数列表的确切定义(用于初始化 class 实例的参数)以及 input/outputs 规范(每当用那个模块进行推断时)。


您可能对符号感到困惑,这里有一个小例子可以帮助您:

  • 元组作为输入:

    def fn1(x, p):
        a, b = p # unpack input
        return a*x + b
    
    >>> fn1(2, (3, 1))
    >>> 7
    
  • 元组作为输出

    def fn2(x):
        return x, (3*x, x**2) # actually output is a tuple of int and tuple 
    
    >>> x, (a, b) = fn2(2) # unpacking
    (2, (6, 4))
    
    >>> x, a, b
    (2, 6, 4)