Trying to achieve same result with Pytorch and Tensorflow MultiheadAttention

I am trying to recreate a transformer that was written in Pytorch and implement it in Tensorflow. The problem is that, even though both the Pytorch version and the Tensorflow version are documented, their results still differ considerably. I wrote a small piece of code to illustrate the problem:

import torch
import tensorflow as tf
import numpy as np

class TransformerLayer(tf.Module):
    def __init__(self, d_model, nhead, dropout=0):
        super(TransformerLayer, self).__init__()
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)

batch_size = 2
seq_length = 5
d_model = 10

src = np.random.uniform(size=(batch_size, seq_length, d_model))
srcTF = tf.convert_to_tensor(src)
srcPT = torch.Tensor(src.reshape((seq_length, batch_size, d_model)))

self_attnTF = tf.keras.layers.MultiHeadAttention(key_dim=10, num_heads=5, dropout=0)
transformer_encoder = TransformerLayer(d_model=10, nhead=5, dropout=0.0)

output, scores = self_attnTF(srcTF, srcTF, srcTF, return_attention_scores=True)
print("Tensorflow Attendtion outputs:", output)
print("Tensorflow (averaged) weights:", tf.math.reduce_mean(scores, 1))
print("Torch Attendtion outputs:", transformer_encoder.self_attn(srcPT,srcPT,srcPT)[0])
print("Torch attention output weights:", transformer_encoder.self_attn(srcPT,srcPT,srcPT)[1])

The results are:

Tensorflow Attention outputs: tf.Tensor(
[[[ 0.02602757 -0.14134401  0.00855263  0.4735083  -0.01851891
   -0.20382246 -0.18152176 -0.21076852  0.08623976 -0.33548725]
  [ 0.02607442 -0.1403394   0.00814065  0.47415024 -0.01882939
   -0.20353754 -0.18291879 -0.21234266  0.08595885 -0.33613583]
  [ 0.02524654 -0.14096384  0.00870436  0.47411725 -0.01800703
   -0.20486829 -0.18163288 -0.21082559  0.08571021 -0.3362339 ]
  [ 0.02518575 -0.14039244  0.0090138   0.47431853 -0.01775141
   -0.20391947 -0.18138805 -0.2118245   0.08432849 -0.33521986]
  [ 0.02556361 -0.14039293  0.00876258  0.4746476  -0.01891363
   -0.20398234 -0.18229616 -0.21147579  0.08555281 -0.33639923]]

 [[ 0.07844199 -0.1614371   0.01649148  0.5287745   0.05126739
   -0.13851154 -0.09829871 -0.1621251   0.01922669 -0.2428589 ]
  [ 0.07844222 -0.16024739  0.01805423  0.52941847  0.04975721
   -0.13537636 -0.09829231 -0.16129729  0.01979005 -0.24491176]
  [ 0.07800542 -0.160701    0.01677295  0.52902794  0.05082911
   -0.13843337 -0.09805533 -0.16165744  0.01928401 -0.24327613]
  [ 0.07815789 -0.1600025   0.01757433  0.5291927   0.05032986
   -0.1368022  -0.09849522 -0.16172451  0.01929555 -0.24438493]
  [ 0.0781548  -0.16028519  0.01764914  0.52846324  0.04941286
   -0.13746066 -0.09787872 -0.16141161  0.01994199 -0.2440269 ]]], shape=(2, 5, 10), dtype=float32)
Tensorflow (averaged) weights: tf.Tensor(
[[[0.199085   0.20275716 0.20086522 0.19873264 0.19856   ]
  [0.2015336  0.19960018 0.20218948 0.19891861 0.19775811]
  [0.19906266 0.20318432 0.20190334 0.19812575 0.19772394]
  [0.20074987 0.20104568 0.20269363 0.19744729 0.19806348]
  [0.19953248 0.20176074 0.20314851 0.19782843 0.19772986]]

 [[0.2010009  0.20053487 0.20004745 0.20092985 0.19748697]
  [0.20034568 0.20035927 0.19955876 0.20062163 0.19911464]
  [0.19967113 0.2006859  0.20012529 0.20047483 0.19904283]
  [0.20132652 0.19996871 0.20019794 0.20008174 0.19842513]
  [0.2006393  0.20000939 0.19938737 0.20054278 0.19942114]]], shape=(2, 5, 5), dtype=float32)
Torch Attention outputs: tensor([[[ 0.1097, -0.4467, -0.0719, -0.1779, -0.0766, -0.1247,  0.1557,
           0.0051, -0.3932, -0.1323],
         [ 0.1264, -0.3822,  0.0759, -0.0335, -0.1084, -0.1539,  0.1475,
          -0.0272, -0.4235, -0.1744]],

        [[ 0.1122, -0.4502, -0.0747, -0.1796, -0.0756, -0.1271,  0.1581,
           0.0049, -0.3964, -0.1340],
         [ 0.1274, -0.3823,  0.0754, -0.0356, -0.1091, -0.1547,  0.1477,
          -0.0272, -0.4252, -0.1752]],

        [[ 0.1089, -0.4427, -0.0728, -0.1746, -0.0756, -0.1202,  0.1501,
           0.0031, -0.3894, -0.1242],
         [ 0.1263, -0.3820,  0.0718, -0.0374, -0.1063, -0.1562,  0.1485,
          -0.0271, -0.4233, -0.1761]],

        [[ 0.1061, -0.4369, -0.0685, -0.1696, -0.0772, -0.1173,  0.1454,
           0.0012, -0.3860, -0.1201],
         [ 0.1265, -0.3820,  0.0762, -0.0325, -0.1082, -0.1560,  0.1501,
          -0.0271, -0.4249, -0.1779]],

        [[ 0.1043, -0.4402, -0.0705, -0.1719, -0.0791, -0.1205,  0.1508,
           0.0018, -0.3895, -0.1262],
         [ 0.1260, -0.3805,  0.0775, -0.0298, -0.1083, -0.1547,  0.1494,
          -0.0276, -0.4242, -0.1768]]], grad_fn=<AddBackward0>)
Torch attention output weights: tensor([[[0.2082, 0.2054, 0.1877, 0.1956, 0.2031],
         [0.2100, 0.2079, 0.1841, 0.1943, 0.2037],
         [0.2007, 0.1995, 0.1929, 0.1999, 0.2070],
         [0.1995, 0.1950, 0.1976, 0.2002, 0.2077],
         [0.1989, 0.1969, 0.1970, 0.2024, 0.2048]],

        [[0.2095, 0.1902, 0.1987, 0.2027, 0.1989],
         [0.2090, 0.1956, 0.1997, 0.2004, 0.1952],
         [0.2047, 0.1869, 0.2006, 0.2121, 0.1957],
         [0.2073, 0.1953, 0.1982, 0.2014, 0.1978],
         [0.2089, 0.2003, 0.1953, 0.1957, 0.1998]]], grad_fn=<DivBackward0>)

The output weights look similar, but the attention outputs themselves are far apart. Is there any way to make the Tensorflow model behave more like the Pytorch model? Any help would be greatly appreciated!

There is also a projection layer inside MultiHeadAttention, something like

Q = W_q @ input_query + b_q
K = W_k @ input_keys + b_k
V = W_v @ input_values + b_v

The matrices W_q, W_k, W_v and the biases b_q, b_k, b_v are randomly initialized, so a difference in the outputs is to be expected (even between the outputs of two different layers in pytorch given the same input). There is also a projection after the self-attention operation, and it is randomly initialized as well. In tensorflow you can set the weights manually by calling the set_weights method of self_attnTF.
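
As a quick way to see what set_weights expects, you can list the variables of both layers side by side. This is a minimal sketch that reuses the names from the snippet in the question; the exact variable names and ordering on the Keras side can differ between versions:

# Shapes of the Keras layer's query/key/value/output projections.
# The layer must have been called (built) at least once before this works.
for w in self_attnTF.weights:
    print(w.name, w.shape)

# The torch layer stacks W_q, W_k and W_v row-wise into a single in_proj_weight.
for name, p in transformer_encoder.self_attn.named_parameters():
    print(name, tuple(p.shape))

# set_weights takes a list of numpy arrays in the same order as get_weights(),
# so the random initialization can be overwritten with values of your choice, e.g.
# self_attnTF.set_weights([np.zeros_like(w) for w in self_attnTF.get_weights()])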

The correspondence between the weights in tf.keras.layers.MultiHeadAttention and those in torch.nn.MultiheadAttention is not very clear, though; for instance, torch shares weights between heads while tf keeps them unique. So if you take the weights of a pretrained pytorch model and try to put them into a tensorflow model (for whatever reason), it will definitely take more than five minutes.

If, after initializing the pytorch model and the tensorflow model, you go through their parameters step by step and assign them the same values, the results should be identical.
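
To make that concrete, here is a minimal sketch of copying the torch weights into the Keras layer. It rests on a few assumptions rather than the exact setup from the question: key_dim is set to d_model // num_heads (2 here) so the per-head shapes line up, the input to the torch layer is transposed to (seq, batch, d_model) instead of reshaped, and get_weights()/set_weights() order the variables as query, key, value, output (worth verifying against layer.weights on your Keras version):

import numpy as np
import tensorflow as tf
import torch

batch_size, seq_length, d_model, nhead = 2, 5, 10, 5
head_dim = d_model // nhead  # 2; key_dim must equal this for the shapes to match

src = np.random.uniform(size=(batch_size, seq_length, d_model)).astype("float32")
srcTF = tf.convert_to_tensor(src)
# torch.nn.MultiheadAttention expects (seq, batch, d_model) by default,
# so transpose the batch and sequence axes instead of reshaping.
srcPT = torch.from_numpy(src.transpose(1, 0, 2).copy())

attnPT = torch.nn.MultiheadAttention(d_model, nhead, dropout=0.0)
attnTF = tf.keras.layers.MultiHeadAttention(num_heads=nhead, key_dim=head_dim, dropout=0.0)
attnTF(srcTF, srcTF, srcTF)  # call once so the layer is built and set_weights works

# Pull the torch parameters: W_q, W_k, W_v are stacked row-wise in in_proj_weight.
W_in = attnPT.in_proj_weight.detach().numpy()     # (3*d_model, d_model)
b_in = attnPT.in_proj_bias.detach().numpy()       # (3*d_model,)
W_out = attnPT.out_proj.weight.detach().numpy()   # (d_model, d_model)
b_out = attnPT.out_proj.bias.detach().numpy()     # (d_model,)
W_q, W_k, W_v = np.split(W_in, 3, axis=0)
b_q, b_k, b_v = np.split(b_in, 3, axis=0)

# torch applies x @ W.T, Keras applies x @ kernel, so transpose first and then
# split the output dimension into (num_heads, head_dim).
def to_tf_kernel(W):
    return W.T.reshape(d_model, nhead, head_dim)

def to_tf_bias(b):
    return b.reshape(nhead, head_dim)

# Assumed ordering: query, key, value, output (kernel before bias in each pair).
attnTF.set_weights([
    to_tf_kernel(W_q), to_tf_bias(b_q),
    to_tf_kernel(W_k), to_tf_bias(b_k),
    to_tf_kernel(W_v), to_tf_bias(b_v),
    W_out.T.reshape(nhead, head_dim, d_model), b_out,
])

outTF = attnTF(srcTF, srcTF, srcTF)
outPT, _ = attnPT(srcPT, srcPT, srcPT)
# After moving torch's sequence axis back, the outputs should agree up to float error.
print(np.allclose(outTF.numpy(), outPT.detach().numpy().transpose(1, 0, 2), atol=1e-5))

With the weights copied like this, both layers apply the same projections, the same 1/sqrt(head_dim) scaling and the same output projection, so any remaining difference should be floating-point noise.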