Verifying the implementation of Multihead Attention in Transformer

I have implemented multi-head attention for a Transformer. There are so many implementations floating around that it gets confusing. Can someone verify whether my implementation is correct?

The scaled dot-product attention is taken from: https://www.tensorflow.org/tutorials/text/transformer#setup

import tensorflow as tf

def scaled_dot_product(q,k,v):
    #calculates Q . K(transpose)
    qkt = tf.matmul(q,k,transpose_b=True)
    #calculates scaling factor
    dk = tf.math.sqrt(tf.cast(q.shape[-1],dtype=tf.float32))
    scaled_qkt = qkt/dk
    softmax = tf.nn.softmax(scaled_qkt,axis=-1)
    
    z = tf.matmul(softmax,v)
    #shape: (m,Tx,depth), same shape as q,k,v
    return z

class MultiAttention(tf.keras.layers.Layer):
    def __init__(self,d_model,num_of_heads):
        super(MultiAttention,self).__init__()
        self.d_model = d_model
        self.num_of_heads = num_of_heads
        self.depth = d_model//num_of_heads
        self.wq = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(d_model)
        
    def call(self,x):
        
        multi_attn = []
        for i in range(self.num_of_heads):
            Q = self.wq[i](x)
            K = self.wk[i](x)
            V = self.wv[i](x)
            multi_attn.append(scaled_dot_product(Q,K,V))
            
        multi_head = tf.concat(multi_attn,axis=-1)
        multi_head_attention = self.wo(multi_head)
        return multi_head_attention

#Calling the attention 
multi = MultiAttention(d_model=512,num_of_heads=8)
m = 5; sequence_length = 4; word_embedding_dim = 512
sample_ip = tf.constant(tf.random.normal(shape=(m,sequence_length,word_embedding_dim)))
attn = multi(sample_ip)
#shape of op (attn): (5,4,512)

In your implementation, you scale inside scaled_dot_product using the query, but according to the original paper the normalization uses the key. Apart from that, the implementation looks OK, but it is not general.
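For reference, the paper defines scaled dot-product attention as Attention(Q, K, V) = softmax(Q·Kᵀ / sqrt(d_k))·V, where d_k is the dimensionality of the keys, which is why the corrected version below reads the scaling factor from k rather than q.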

class MultiAttention(tf.keras.layers.Layer):
    def __init__(self, num_of_heads, out_dim):
        super(MultiAttention,self).__init__()
        self.out_dim      = out_dim
        self.num_of_heads = num_of_heads
        self.depth        = self.out_dim // self.num_of_heads
        self.wq = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(self.out_dim)
        
    def call(self,x):
        multi_attn = []
        for i in range(self.num_of_heads):
            Q = self.wq[i](x)
            K = self.wk[i](x)
            V = self.wv[i](x)
            multi_attn.append(self.scaled_dot_product(Q,K,V))

        multi_head = tf.concat(multi_attn, axis=-1)
        multi_head_attention = self.wo(multi_head)
        return multi_head_attention

    def scaled_dot_product(self, q,k,v):
        qkt = tf.matmul(q, k, transpose_b=True)
        dk = tf.math.sqrt( tf.cast(k.shape[-1], dtype=tf.float32) )
        scaled_qkt = qkt/dk
        softmax = tf.nn.softmax(scaled_qkt, axis=-1)
        z = tf.matmul(softmax, v)
        return z

multi = MultiAttention(num_of_heads=3, out_dim=32)
sample_ip = tf.random.normal(shape=(2, 2, 32)); print(sample_ip.shape)
multi(sample_ip).shape

The general transformer attention can be illustrated as follows, where the first two linear layers represent the query and key and are responsible for producing the attention weight map, which then weights the value by matrix multiplication.

(Attention architecture diagram omitted; see image source.)
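A shape-only sketch of that weighting step (toy dimensions picked arbitrarily here, single head):

# Toy shapes: 4 query positions, 6 key/value positions, depth 8 per head.
q = tf.random.normal((1, 4, 8))   # output of the query Dense layer
k = tf.random.normal((1, 6, 8))   # output of the key Dense layer
v = tf.random.normal((1, 6, 8))   # output of the value Dense layer

# The query/key projections produce the attention weight map ...
weights = tf.nn.softmax(tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(8.0), axis=-1)
print(weights.shape)                 # (1, 4, 6): one weight per (query, key) position pair

# ... which then weights the values by matrix multiplication.
print(tf.matmul(weights, v).shape)   # (1, 4, 8)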

I know you are trying to minimize the original TF tutorial code, but I think you should first add a reference to the original source. In the original implementation, they also return the attention weights (scores) along with the weighted feature map; I don't think you should skip that.
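If you want to mirror that here, a minimal sketch (adapted from the scaled_dot_product above, not tested end to end) is to return both tensors and append the per-head weights in call alongside the per-head outputs:

def scaled_dot_product(q, k, v):
    qkt = tf.matmul(q, k, transpose_b=True)
    dk = tf.math.sqrt(tf.cast(k.shape[-1], tf.float32))
    attention_weights = tf.nn.softmax(qkt / dk, axis=-1)
    # return the weighted values *and* the attention weights, as the tutorial does
    return tf.matmul(attention_weights, v), attention_weights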


The original code you are following is more general and has been optimized more efficiently.

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
        # scale matmul_qk
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        # add the mask to the scaled tensor.
        if mask is not None: scaled_attention_logits += (mask * -1e9)
        # softmax is normalized on the last axis (seq_len_k) so that the scores
        # add up to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention,  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
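The efficiency comes from using a single Dense(d_model) per Q/K/V and then splitting heads with a reshape and transpose instead of looping over per-head Dense layers. A quick shape check of what split_heads does, using the question's dimensions (d_model=512, 8 heads, so depth=64):

x = tf.random.normal((5, 4, 512))                 # (batch, seq_len, d_model)
x = tf.reshape(x, (5, -1, 8, 64))                 # (batch, seq_len, num_heads, depth)
print(tf.transpose(x, perm=[0, 2, 1, 3]).shape)   # (5, 8, 4, 64)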

FYI, the tf.keras.layers.MultiHeadAttention layer was officially added in TF 2.4.

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
input_tensor = tf.keras.Input(shape=[2, 2, 32]); print(input_tensor.shape)
print(layer(input_tensor, input_tensor).shape)

You can test the two as follows:

# custom layer MHA
multi = MultiHeadAttention(d_model=512, num_heads=2)
y = tf.random.uniform((1, 60, 512))  
out, attn = multi(y, k=y, q=y, mask=None)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))

# built-in layer 
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
y = tf.random.uniform((1, 60, 512))  
out, attn = layer(y, y, return_attention_scores=True)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
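The built-in layer also accepts separate query/value/key tensors, mirroring the custom call(v, k, q) signature above. A quick cross-attention sketch (the shapes in the comment are the expected ones, assuming the default output_shape, which projects back to the query's last dimension):

q  = tf.random.uniform((1, 60, 512))
kv = tf.random.uniform((1, 40, 512))
out, attn = layer(query=q, value=kv, key=kv, return_attention_scores=True)
out.shape, attn.shape
# expected: (TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 40]))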