带有 torchtext 的 Pytorch LSTM 输入维度有问题

Having trouble with input dimensions for Pytorch LSTM with torchtext

问题

我正在尝试使用 LSTM 构建文本分类器网络。我得到的错误是:

RuntimeError: Expected hidden[0] size (4, 600, 256), got (4, 64, 256)

详情

数据是 json,看起来像这样:

{"cat": "music", "desc": "I'm in love with the song's intro!", "sent": "h"}

我正在使用 torchtext 加载数据。

from torchtext import data
from torchtext import datasets

TEXT = data.Field(fix_length = 600)
LABEL = data.Field(fix_length = 10)

BATCH_SIZE = 64

fields = {
    'cat': ('c', LABEL),
    'desc': ('d', TEXT),
    'sent': ('s', LABEL),
}

我的 LSTM 看起来像这样

EMBEDDING_DIM = 64
HIDDEN_DIM = 256
N_LAYERS = 4

MyLSTM(
  (embedding): Embedding(11967, 64)
  (lstm): LSTM(64, 256, num_layers=4, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=256, out_features=8, bias=True)
  (sig): Sigmoid()
)

我最终得到 inputslabels

的以下维度
batch = list(train_iterator)[0]
inputs, labels = batch
print(inputs.shape) # torch.Size([600, 64])
print(labels.shape) # torch.Size([100, 2, 64])

我初始化的隐藏张量看起来像:

hidden # [torch.Size([4, 64, 256]), torch.Size([4, 64, 256])]

问题

我正在尝试了解每一步的尺寸应该是多少。 隐藏维度应该初始化为(4, 600, 256) 还是(4, 64, 256)?

nn.LSTM - Inputs 的文档解释了维度是什么:

  • h_0 of shape (num_layers * num_directions, batch, hidden_size): tensor containing the initial hidden state for each element in the batch. If the LSTM is bidirectional, num_directions should be 2, else it should be 1.

因此,您的隐藏状态的大小应为 (4, 64, 256),所以您做对了。另一方面,您没有提供正确的输入尺寸。

  • input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence. The input can also be a packed variable length sequence. See torch.nn.utils.rnn.pack_padded_sequence() or torch.nn.utils.rnn.pack_sequence() for details.

虽然它说输入的大小需要 (seq_len, batch, input_size),但你设置了 batch_first=True 在你的 LSTM 中,它交换 batchseq_len。因此,您的输入应具有大小 (batch_size, seq_len, input_size),但事实并非如此,因为您的输入具有 seq_len 第一个 (600) 和 batch 第二个 (64),这是 torchtext 中的默认值,因为这是更常见的表示形式,也与默认值匹配LSTM 的行为。

您需要在您的 LSTM 中设置 batch_first=False

或者。如果您更喜欢将 batch 作为一般的第一个维度,torch.data.Field 也有 batch_first 选项。