Pytorch nn.embedding 错误
Pytorch nn.embedding error
我正在阅读 Word Embedding 上的 pytorch 文档。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.manual_seed(5)
word_to_ix = {"hello": 0, "world": 1, "how":2, "are":3, "you":4}
embeds = nn.Embedding(2, 5) # 2 words in vocab, 5 dimensional embeddings
lookup_tensor = torch.tensor(word_to_ix["hello"], dtype=torch.long)
hello_embed = embeds(lookup_tensor)
print(hello_embed)
输出:
tensor([-0.4868, -0.6038, -0.5581, 0.6675, -0.1974])
这看起来不错,但如果我将行 lookup_tensor 替换为
lookup_tensor = torch.tensor(word_to_ix["how"], dtype=torch.long)
我得到的错误是:
RuntimeError: index out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1524590658547/work/aten/src/TH/generic/THTensorMath.c:343
我不明白为什么它在第 hello_embed = embeds(lookup_tensor)
行给出运行时错误。
当您声明 embeds = nn.Embedding(2, 5)
时,vocab 大小为 2,嵌入大小为 5。即每个单词将由大小为 5 的向量表示,vocab 中只有 2 个单词。
lookup_tensor = torch.tensor(word_to_ix["how"], dtype=torch.long)
embeds 将尝试查找与 vocab 中的第三个单词对应的向量,但嵌入的 vocab 大小为 2。这就是你得到错误的原因。
如果您声明 embeds = nn.Embedding(5, 5)
它应该可以正常工作。
我正在阅读 Word Embedding 上的 pytorch 文档。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.manual_seed(5)
word_to_ix = {"hello": 0, "world": 1, "how":2, "are":3, "you":4}
embeds = nn.Embedding(2, 5) # 2 words in vocab, 5 dimensional embeddings
lookup_tensor = torch.tensor(word_to_ix["hello"], dtype=torch.long)
hello_embed = embeds(lookup_tensor)
print(hello_embed)
输出:
tensor([-0.4868, -0.6038, -0.5581, 0.6675, -0.1974])
这看起来不错,但如果我将行 lookup_tensor 替换为
lookup_tensor = torch.tensor(word_to_ix["how"], dtype=torch.long)
我得到的错误是:
RuntimeError: index out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1524590658547/work/aten/src/TH/generic/THTensorMath.c:343
我不明白为什么它在第 hello_embed = embeds(lookup_tensor)
行给出运行时错误。
当您声明 embeds = nn.Embedding(2, 5)
时,vocab 大小为 2,嵌入大小为 5。即每个单词将由大小为 5 的向量表示,vocab 中只有 2 个单词。
lookup_tensor = torch.tensor(word_to_ix["how"], dtype=torch.long)
embeds 将尝试查找与 vocab 中的第三个单词对应的向量,但嵌入的 vocab 大小为 2。这就是你得到错误的原因。
如果您声明 embeds = nn.Embedding(5, 5)
它应该可以正常工作。