为什么我们要做批量矩阵矩阵产品?

Why do we do batch matrix-matrix product?

我正在关注 Pytorch seq2seq tutorial,它torch.bmm 方法的用法如下:

attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                         encoder_outputs.unsqueeze(0))

我明白为什么我们需要将注意力权重和编码器输出相乘。

我不太明白的是为什么我们这里需要bmm方法。 torch.bmm 文档说

Performs a batch matrix-matrix product of matrices stored in batch1 and batch2.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a (b×n×m) tensor, batch2 is a (b×m×p) tensor, out will be a (b×n×p) tensor.

上图中描述的操作发生在Seq2Seq模型的Decoder端。这意味着 编码器输出 已经在批次方面(具有 小批量大小 样本)。因此,attn_weights 张量也应该处于批处理模式。

因此,本质上,张量 attn_weightsencoder_outputs 的第一个维度(NumPy 术语中的第 zero 个轴)是 个样本数小批量大小。因此,我们需要 torch.bmm 这两个张量。

在 seq2seq 模型中,编码器将输入序列编码为小批量。例如,输入是 B x S x d,其中 B 是批量大小,S 是最大序列长度,d 是词嵌入维数。然后编码器的输出是 B x S x h 其中 h 是编码器(它是一个 RNN)的隐藏状态大小。

现在正在解码(训练期间) 一次给定一个输入序列,因此输入为B x 1 x d,解码器产生形状为B x 1 x h的张量。现在要计算上下文向量,我们需要将解码器的隐藏状态与编码器的编码状态进行比较。

因此,假设您有两个形状为 T1 = B x S x hT2 = B x 1 x h 的张量。因此,如果您可以按如下方式进行批量矩阵乘法。

out = torch.bmm(T1, T2.transpose(1, 2))

本质上,你是将形状为 B x S x h 的张量与形状为 B x h x 1 的张量相乘,结果将是 B x S x 1,这是每个批次的注意力权重。

这里,注意力权重B x S x 1表示解码器当前隐藏状态与编码器所有隐藏状态之间的相似度得分。现在,您可以先通过转置将注意力权重与编码器的隐藏状态 B x S x h 相乘,这将产生形状为 B x h x 1 的张量。如果你在 dim=2 时执行挤压,你将得到一个形状为 B x h 的张量,这是你的上下文向量。

这个上下文向量(B x h)通常连接到解码器的隐藏状态(B x 1 x h,squeeze dim=1)以预测下一个标记。

虽然 @wasiahmad 关于 seq2seq 的一般实现是正确的,但在提到的教程中没有批处理 (B=1),bmm 只是过度工程,可以安全地替换为 matmul 具有完全相同的模型质量和性能。自己看看,替换这个:

        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)

有了这个:

        attn_applied = torch.matmul(attn_weights,
                                 encoder_outputs)
        output = torch.cat((embedded[0], attn_applied), 1)

和 运行 笔记本。


此外,请注意,虽然@wasiahmad 将编码器输入称为 B x S x d,但在 pytorch 1.7.0 中,作为编码器主引擎的 GRU 期望输入格式为 (seq_len, batch, input_size) 默认。如果您想使用@wasiahmad 格式,请传递 batch_first = True 标志。