pytorch中LSTM的参数和函数调用
arguments and function call of LSTM in pytorch
谁能给我解释一下下面的代码:
import torch
import torch.nn as nn
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
rnn = nn.LSTM(10,20,2)
output, (hn, cn) = rnn(input, (h0, c0))
print(input)
在调用 rnn rnn(input, (h0, c0))
时,我们在括号中给出了参数 h0 和 c0。这是什么意思?如果 (h0, c0) 表示单个值,那么该值是什么,这里传递的第三个参数是什么?
但是,在行 rnn = nn.LSTM(10,20,2)
中,我们在 LSTM 函数中传递参数而不带括号。
谁能解释一下这个函数调用是如何工作的?
赋值 rnn = nn.LSTM(10, 20, 2)
使用 nn.LSTM
class 实例化了一个新的 nn.Module
。它的前三个参数是input_size
(此处为10
)、hidden_size
(此处为20
)和num_layers
(此处为2
)。
另一方面 rnn(input, (h0, c0))
对应于实际调用 class 实例, i.e.
运行 __call__
大致相当于 forward
该模块的功能。 nn.LSTM
的 __call__
方法接受两个参数:input
(形 (sequnce_length, batch_size, input_size)
和两个张量 (h_0, c_0)
的元组(形如 (num_layers, batch_size, hidden_size)
在nn.LSTM
)
的基本用例中
请在使用内置函数时参考 PyTorch 文档,您会找到参数列表的确切定义(用于初始化 class 实例的参数)以及 input/outputs 规范(每当用那个模块进行推断时)。
您可能对符号感到困惑,这里有一个小例子可以帮助您:
元组作为输入:
def fn1(x, p):
a, b = p # unpack input
return a*x + b
>>> fn1(2, (3, 1))
>>> 7
元组作为输出
def fn2(x):
return x, (3*x, x**2) # actually output is a tuple of int and tuple
>>> x, (a, b) = fn2(2) # unpacking
(2, (6, 4))
>>> x, a, b
(2, 6, 4)
谁能给我解释一下下面的代码:
import torch
import torch.nn as nn
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
rnn = nn.LSTM(10,20,2)
output, (hn, cn) = rnn(input, (h0, c0))
print(input)
在调用 rnn rnn(input, (h0, c0))
时,我们在括号中给出了参数 h0 和 c0。这是什么意思?如果 (h0, c0) 表示单个值,那么该值是什么,这里传递的第三个参数是什么?
但是,在行 rnn = nn.LSTM(10,20,2)
中,我们在 LSTM 函数中传递参数而不带括号。
谁能解释一下这个函数调用是如何工作的?
赋值 rnn = nn.LSTM(10, 20, 2)
使用 nn.LSTM
class 实例化了一个新的 nn.Module
。它的前三个参数是input_size
(此处为10
)、hidden_size
(此处为20
)和num_layers
(此处为2
)。
另一方面 rnn(input, (h0, c0))
对应于实际调用 class 实例, i.e.
运行 __call__
大致相当于 forward
该模块的功能。 nn.LSTM
的 __call__
方法接受两个参数:input
(形 (sequnce_length, batch_size, input_size)
和两个张量 (h_0, c_0)
的元组(形如 (num_layers, batch_size, hidden_size)
在nn.LSTM
)
请在使用内置函数时参考 PyTorch 文档,您会找到参数列表的确切定义(用于初始化 class 实例的参数)以及 input/outputs 规范(每当用那个模块进行推断时)。
您可能对符号感到困惑,这里有一个小例子可以帮助您:
元组作为输入:
def fn1(x, p): a, b = p # unpack input return a*x + b >>> fn1(2, (3, 1)) >>> 7
元组作为输出
def fn2(x): return x, (3*x, x**2) # actually output is a tuple of int and tuple >>> x, (a, b) = fn2(2) # unpacking (2, (6, 4)) >>> x, a, b (2, 6, 4)