使用来自 gensim 的 pre_trained 向量对 torch 嵌入层的预期输入

Expected input to torch Embedding layer with pre_trained vectors from gensim

我想在我的神经网络架构中使用预训练嵌入。预训练嵌入由 gensim 训练。我发现 这表明我们可以像这样加载 pre_trained 模型:

import gensim
from torch import nn

model = gensim.models.KeyedVectors.load_word2vec_format('path/to/file')
weights = torch.FloatTensor(model.vectors)
emb = nn.Embedding.from_pretrained(torch.FloatTensor(weights.vectors))

这似乎在 1.0.1 上也能正常工作。我的问题是,我不太明白我必须将什么输入到这样的层中才能使用它。我可以只喂令牌(分段句子)吗?我是否需要映射,例如令牌到索引?

我发现你可以简单地通过类似

的方式访问令牌的向量
print(weights['the'])
# [-1.1206588e+00  1.1578362e+00  2.8765252e-01 -1.1759659e+00 ... ]

这对 RNN 架构意味着什么?我们可以简单地加载批处理序列的标记吗?例如:

for seq_batch, y in batch_loader():
    # seq_batch is a batch of sequences (tokenized sentences)
    # e.g. [['i', 'like', 'cookies'],['it', 'is', 'raining'],['who', 'are', 'you']]
    output, hidden = model(seq_batch, hidden)

这似乎不起作用,所以我假设您需要将标记转换为其在最终 word2vec 模型中的索引。真的吗?我发现您可以使用 word2vec 模型的 vocab:

获取单词的索引
weights.vocab['world'].index
# 147

那么作为嵌入层的输入,我是否应该为由单词序列组成的句子序列提供 int 张量?与虚拟数据加载器(参见上面的示例)和虚拟 RNN 欢迎使用的示例。

documentation 表示如下

This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.

所以如果你想输入一个句子,你给一个 LongTensor of 索引,每个索引对应词汇表中的一个词,nn.Embedding 层将映射到词向量中。

这是一个插图

test_voc = ["ok", "great", "test"]
# The word vectors for "ok", "great" and "test"
# are at indices, 0, 1 and 2, respectively.

my_embedding = torch.rand(3, 50)
e = nn.Embedding.from_pretrained(my_embedding)

# LongTensor of indicies corresponds to a sentence,
# reshaped to (1, 3) because batch size is 1
my_sentence = torch.tensor([0, 2, 1]).view(1, -1)

res = e(my_sentence)
print(res.shape)
# => torch.Size([1, 3, 50])
# 1 is the batch dimension, and there's three vectors of length 50 each

就 RNN 而言,接下来您可以将该张量输入到您的 RNN 模块中,例如

lstm = nn.LSTM(input_size=50, hidden_size=5, batch_first=True)
output, h = lstm(res)
print(output.shape)
# => torch.Size([1, 3, 5])

我还建议您查看 torchtext。它可以自动执行一些您必须手动执行的操作。