对 PyTorch GRU 文档感到困惑

Confused Regarding PyTorch GRU Docs

这可能是一个太基础的问题,但是文档中 GRU 的输入需要 3 维是什么意思? PyTorch 状态的 GRU 文档:

input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence. The input can also be a packed variable length sequence. See torch.nn.utils.rnn.pack_padded_sequence() for details.

https://pytorch.org/docs/stable/generated/torch.nn.GRU.html

假设我正在尝试预测序列中的下一个 # 并具有以下数据集:

n, label
1, 2
2, 3
3, 6
4, 9
...

如果我 window 在猜测下一个时考虑使用前 2 个输入的数据,则数据集变为:

t-2, t-1, t, label
na, na, 1, 2
na, 1, 2, 3
1, 2, 3, 6
2, 3, 4, 10
...

其中 t-x 仅表示使用来自先前时间步长的输入值。

因此,在创建顺序加载程序时,应该为第 1、2、3、6 行创建以下张量:

inputs: tensor([[1,2,3]]) #shape(1,3)
labels: tensor([[6]])     #shape(1,1)

我目前将输入形状理解为(# batches,# features per batch) 输出形状为(# batches,# output features per batch)

我的问题是,输入张量应该像这样:

tensor([[[1],[2],[3]]])

代表(#批次,#prior inputs to consider,#features per input)

我想我最好尝试理解为什么 GRU 的输入在 PyTorch 中有 3 个维度。第三维度从根本上代表什么?如果我有一个像上面这样的转换数据集,如何将它正确地传递给模型。

编辑: 所以现在的模式是:

1 + 1 = 2
2 + 1 = 3
3 + 2 + 1 = 6
4+ 3 + 2 + 1 = 10

我想要它,其中 t-2、t-1 和 t 代表用于帮助猜测的每个时间步长的特征。例如,在每个时间点可能有 2 个特征。尺寸为(1 个批量大小、3 个时间步长、2 个特征)。

我的问题是 GRU 是否采用扁平化输入:

(1 batch size, 3 time steps * 2 features per time step)

或未展平的输入:

(1 batch size, 3 time steps, 2 features per timestep)

我目前的印象是它是第二个输入,但想检查一下我的理解。

nn.GRU 模块与其他 PyTorch RNN 模块一样工作。如果参数 batch_first 设置为 True,它需要一个三维张量 (seq_len, batch, input_size)(batch, seq_len, input_size)。我认为最后一个维度是困扰你的问题。

你解释说你的序列设置如下:

t-2 t-1 t label
na na 1 2
na 1 2 3
1 2 3 6
2 3 4 10

您缺少的是输入编码:您将如何表示您的预测和标签?像这样输入整数是行不通的。您可能需要将数据转换为单热编码。

假设有 10 个不同的标签,即您的词汇表由 10 个元素组成。转换为单热编码是一个直接的过程。取一个长度为词汇表大小的零向量,并将 1 放在与特定标签对应的索引处。

词汇量又会是……input_size。给定一个 label,这看起来像:

encoding = torch.zeros(input_size)
encoding[label] = 1
label one-hot-encoding
0 [1,0,0,..., 0, 0]
1 [0,1,0,..., 0, 0]
... ...
9 [0,0,0,..., 0, 1]

因此,你的训练点(输入序列1, 2, 3标签6)将翻译到(输入序列[[0,1,0,0,0,0,0,0,0,0], [0,0,1,0,0,0,0,0,0,0], [0,0,1,0,0,0,0,0,0,0]]标签6)。这是二维的,为批处理添加了额外的维度(见第一节),这就是三。

我故意保留标签,因为 PyTorch 损失函数(例如 nn.CrossEntropyLoss)通常需要索引而不是目标的单热编码(//标签)。

我想通了。本质上,序列长度为 3 意味着输入到系统需要是:[[[1],[2],[3]], [[2], [3], [4]]] 对于一个 batch大小为 2,序列长度为 3,每个时间步的特征输入为 1。本质上,每个序列都是在某个时间 t 考虑的输入。