相当于torch EmbeddingBag

Equivalent of torch EmbeddingBag

Torch 声称带有 mode="sum" 的 EmbeddingBag 等同于 Embedding 后接 torch.sum(dim=1),但我该如何具体实现呢?假设我们有 “EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)”,我们如何用“nn.Embeeding”和“[=12”替换“nn.EmbeddingBag” =]”等同于?非常感谢

考虑以下示例,其中所有四种实现都产生相同的结果:

  • nn.EmbeddingBag:

    >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
    >>> embedding_sum(input, torch.zeros(1).long())
    
  • nn.functional.embedding_bag:

    >>> F.embedding_bag(input, embedding_sum.weight, torch.zeros(1).long(), mode='sum')
    
  • nn.Embedding:

    >>> embedding = nn.Embedding(10, 3)
    >>> embedding.weight = embedding_sum.weight
    >>> embedding(input).sum(0)
    
  • nn.functional.embedding:

    >>> F.embedding(input, embedding_sum.weight).sum(0)