相当于torch EmbeddingBag
Equivalent of torch EmbeddingBag
Torch 声称带有 mode="sum" 的 EmbeddingBag 等同于 Embedding 后接 torch.sum(dim=1),但我该如何具体实现呢?假设我们有
“EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)”,我们如何用“nn.Embeeding”和“[=12”替换“nn.EmbeddingBag” =]”等同于?非常感谢
考虑以下示例,其中所有四种实现都产生相同的结果:
-
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
>>> embedding_sum(input, torch.zeros(1).long())
-
>>> F.embedding_bag(input, embedding_sum.weight, torch.zeros(1).long(), mode='sum')
-
>>> embedding = nn.Embedding(10, 3)
>>> embedding.weight = embedding_sum.weight
>>> embedding(input).sum(0)
-
>>> F.embedding(input, embedding_sum.weight).sum(0)
Torch 声称带有 mode="sum" 的 EmbeddingBag 等同于 Embedding 后接 torch.sum(dim=1),但我该如何具体实现呢?假设我们有 “EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)”,我们如何用“nn.Embeeding”和“[=12”替换“nn.EmbeddingBag” =]”等同于?非常感谢
考虑以下示例,其中所有四种实现都产生相同的结果:
-
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum') >>> embedding_sum(input, torch.zeros(1).long())
-
>>> F.embedding_bag(input, embedding_sum.weight, torch.zeros(1).long(), mode='sum')
-
>>> embedding = nn.Embedding(10, 3) >>> embedding.weight = embedding_sum.weight >>> embedding(input).sum(0)
-
>>> F.embedding(input, embedding_sum.weight).sum(0)