BigBird,或稀疏自注意力:如何实现稀疏矩阵?

BigBird, or Sparse self-attention: How to implement a sparse matrix?

此问题与新论文相关:Big Bird: Transformers for Longer Sequences. Mainly, about the implementation of the Sparse Attention (that is specified in the Supplemental material, part D)。目前,我正在尝试在 PyTorch 中实现它。

他们提出了一种通过阻止原始查询和密钥矩阵来加速计算的新方法(见下文)

当你在步骤 (b) 中进行矩阵乘积时,你最终会得到类似这样的结果: .

所以我想知道:你如何从该表示(上图)到稀疏矩阵(使用 PyTorch,见下文)? 在论文中,他们只是说:“简单地重塑结果”,我不知道有什么简单的方法可以做到这一点(尤其是当我在不同位置有多个块时(参见第一张图片上的步骤(c) ).

解决方案: Huggingface 在 pytorch 中实现了 BigBird。

我最终遵循了本文中的指南。当涉及到结果的解包时,我使用:torch.sparse_coo_tensor

编辑:稀疏张量仍然需要大量内存!这里描述了