验证 Multihead Attention 在 Transformer 中的实现
Verifying the implementation of Multihead Attention in Transformer
我在 Transformer 中实现了多头注意力(代码中的 MultiAttention)。网上的实现太多,让人很困惑。有人能帮我验证一下我的实现是否正确吗?
其中的缩放点积注意力(DotProductAttention)参考自:https://www.tensorflow.org/tutorials/text/transformer#setup
import tensorflow as tf
def scaled_dot_product(q, k, v):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        q: Query tensor, shape (..., seq_len_q, depth).
        k: Key tensor, shape (..., seq_len_k, depth). Its last dimension must
            match ``q`` for the Q K^T product.
        v: Value tensor, shape (..., seq_len_k, depth_v).

    Returns:
        Tensor of shape (..., seq_len_q, depth_v).
    """
    # Q . K^T -> (..., seq_len_q, seq_len_k)
    qkt = tf.matmul(q, k, transpose_b=True)
    # Scale by sqrt(d_k), the *key* depth as in the paper. tf.shape works even
    # when the static last dimension is unknown. Casting to the logits' dtype
    # keeps the division valid for non-float32 inputs (e.g. float16).
    dk = tf.math.sqrt(tf.cast(tf.shape(k)[-1], qkt.dtype))
    # Softmax over the key axis, so each query's weights sum to 1.
    weights = tf.nn.softmax(qkt / dk, axis=-1)
    # (..., seq_len_q, depth_v)
    return tf.matmul(weights, v)
class MultiAttention(tf.keras.layers.Layer):
    """Multi-head self-attention built from one set of Dense projections per head.

    Each head projects the input to ``depth = d_model // num_of_heads``
    features for Q, K and V and runs the module-level ``scaled_dot_product``.
    The head outputs are concatenated and projected back to ``d_model`` by
    ``wo``.

    Args:
        d_model: Size of the output feature dimension.
        num_of_heads: Number of attention heads; must be positive and divide
            ``d_model`` so the heads together cover exactly ``d_model``
            features.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_of_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_of_heads, **kwargs):
        super().__init__(**kwargs)
        if num_of_heads <= 0 or d_model % num_of_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a multiple of "
                f"num_of_heads ({num_of_heads})"
            )
        self.d_model = d_model
        self.num_of_heads = num_of_heads
        self.depth = d_model // num_of_heads
        # One independent Q/K/V projection per head.
        self.wq = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(d_model)

    def call(self, x):
        """Attend over ``x`` with every head and project the concatenation.

        Args:
            x: Input tensor; the Dense projections act on its last axis.

        Returns:
            Tensor with the same leading axes as ``x`` and ``d_model``
            features on the last axis.
        """
        heads = [
            scaled_dot_product(wq(x), wk(x), wv(x))
            for wq, wk, wv in zip(self.wq, self.wk, self.wv)
        ]
        return self.wo(tf.concat(heads, axis=-1))
# Calling the attention
m = 5
sequence_length = 4
word_embedding_dim = 512
multi = MultiAttention(d_model=word_embedding_dim, num_of_heads=8)
# tf.random.normal already returns a tensor; wrapping it in tf.constant is redundant.
sample_ip = tf.random.normal(shape=(m, sequence_length, word_embedding_dim))
attn = multi(sample_ip)
# The output keeps batch and sequence axes and has d_model features: (5, 4, 512).
assert attn.shape == (m, sequence_length, word_embedding_dim)
在您的实现中,scaled_dot_product 使用 query 的最后一维进行缩放,而根据原论文,缩放因子应取 key 的维度,即除以 $\sqrt{d_k}$。(由于计算 $QK^\top$ 要求 q 与 k 的最后一维相同,两者在数值上其实相等,但按 key 取值更符合论文的定义。)除此之外,这个实现看起来没有问题,只是不够通用。
class MultiAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with per-head Dense projections.

    Each of the ``num_of_heads`` heads projects the input to
    ``depth = out_dim // num_of_heads`` features for Q, K and V and applies
    scaled dot-product attention. The head outputs are concatenated and
    projected to ``out_dim`` by ``wo``.

    Note: ``out_dim`` need not be a multiple of ``num_of_heads``. In that case
    the concatenation has ``num_of_heads * depth`` (< ``out_dim``) features,
    and ``wo`` still maps it to ``out_dim`` (e.g. 3 heads, out_dim=32 gives
    30 concatenated features projected to 32).

    Args:
        num_of_heads: Number of attention heads.
        out_dim: Size of the output feature dimension.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_of_heads, out_dim, **kwargs):
        super().__init__(**kwargs)
        self.out_dim = out_dim
        self.num_of_heads = num_of_heads
        self.depth = self.out_dim // self.num_of_heads
        # One independent Q/K/V projection per head.
        self.wq = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(self.out_dim)

    def call(self, x):
        """Apply every head to ``x`` (features on the last axis) and project to ``out_dim``."""
        heads = [
            self.scaled_dot_product(wq(x), wk(x), wv(x))
            for wq, wk, wv in zip(self.wq, self.wk, self.wv)
        ]
        return self.wo(tf.concat(heads, axis=-1))

    def scaled_dot_product(self, q, k, v):
        """Return softmax(Q K^T / sqrt(d_k)) V, where d_k is the key depth."""
        qkt = tf.matmul(q, k, transpose_b=True)
        # tf.shape works when the static last dim is unknown. Casting to the
        # logits' dtype lets non-float32 inputs (e.g. float16) divide cleanly.
        dk = tf.math.sqrt(tf.cast(tf.shape(k)[-1], qkt.dtype))
        # Softmax over the key axis.
        weights = tf.nn.softmax(qkt / dk, axis=-1)
        return tf.matmul(weights, v)
multi = MultiAttention(num_of_heads=3, out_dim=32)
sample_ip = tf.random.normal(shape=(2, 2, 32))
print(sample_ip.shape)
# A bare expression only displays in a notebook; print it so a script shows it too.
print(multi(sample_ip).shape)
一般的 Transformer 注意力结构可以用下图表示(原图未保留):前两个线性层分别生成 query 和 key,由它们计算出注意力权重图,然后再通过矩阵乘法用这些权重对 value 进行加权。
我知道您想精简原始的 TF 教程代码,但我认为您应该先注明原始代码的出处。在原始实现中,除了加权后的特征图之外,还会一并返回注意力权重(即概率或分数),我认为不应省略这一点。
您参考的原始代码更通用,也做了更高效的优化。
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention as in the TensorFlow Transformer tutorial.

    Q, K and V are each projected to ``d_model`` features with a single Dense
    layer, split into ``num_heads`` heads of ``depth = d_model // num_heads``,
    attended in parallel, merged back, and projected by ``dense``.

    Args:
        d_model: Model (and output) feature size.
        num_heads: Number of heads; must be positive and divide ``d_model``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Raise instead of ``assert``: asserts are stripped under ``python -O``.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a multiple of num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k) + mask * -1e9) V.

        Args:
            q: (..., seq_len_q, depth) queries.
            k: (..., seq_len_k, depth) keys.
            v: (..., seq_len_k, depth_v) values.
            mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k).
                Positions holding 1 get -1e9 added to their logit, which
                effectively zeroes their weight. Any numeric or bool dtype is
                accepted; it is cast to the logits' dtype.

        Returns:
            Tuple ``(output, attention_weights)`` with shapes
            (..., seq_len_q, depth_v) and (..., seq_len_q, seq_len_k).
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

        # Scale by sqrt(d_k), the key depth.
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # Cast so integer/bool masks work; a float32 mask is unchanged.
        if mask is not None:
            scaled_attention_logits += (
                tf.cast(mask, scaled_attention_logits.dtype) * -1e9
            )

        # Softmax over the last axis (seq_len_k), so each query's scores sum to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

        return output, attention_weights

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Transpose the result so the shape is (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Returns:
            Tuple ``(output, attention_weights)``. ``output`` has shape
            (batch_size, seq_len_q, d_model); ``attention_weights`` has shape
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(
            q, k, v, mask
        )

        # Merge the heads back: (batch_size, seq_len_q, num_heads, depth) -> (batch_size, seq_len_q, d_model)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights
仅供参考:TF 2.4 中正式加入了 tf.keras.layers.MultiHeadAttention 层。
# Built-in layer (TF >= 2.4). keras.Input's shape excludes the batch axis,
# so this symbolic input is 4-D: (None, 2, 2, 32).
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
input_tensor = tf.keras.Input(shape=[2, 2, 32])
print(input_tensor.shape)
# Self-attention: the same tensor serves as query and value.
result = layer(input_tensor, input_tensor)
print(result.shape)
您可以按如下方式分别测试这两种实现:
# custom layer MHA
multi = MultiHeadAttention(d_model=512, num_heads=2)
y = tf.random.uniform((1, 60, 512))
out, attn = multi(y, k=y, q=y, mask=None)
print(out.shape, attn.shape)
# Expected (notebook output): (TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
assert out.shape == (1, 60, 512)
assert attn.shape == (1, 2, 60, 60)
# built-in layer
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
y = tf.random.uniform((1, 60, 512))
out, attn = layer(y, y, return_attention_scores=True)
print(out.shape, attn.shape)
# Expected (notebook output): (TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
assert out.shape == (1, 60, 512)
assert attn.shape == (1, 2, 60, 60)
我在 Transformer 中实现了多头注意力(代码中的 MultiAttention)。网上的实现太多,让人很困惑。有人能帮我验证一下我的实现是否正确吗?
其中的缩放点积注意力(DotProductAttention)参考自:https://www.tensorflow.org/tutorials/text/transformer#setup
import tensorflow as tf
def scaled_dot_product(q, k, v):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        q: Query tensor, shape (..., seq_len_q, depth).
        k: Key tensor, shape (..., seq_len_k, depth). Its last dimension must
            match ``q`` for the Q K^T product.
        v: Value tensor, shape (..., seq_len_k, depth_v).

    Returns:
        Tensor of shape (..., seq_len_q, depth_v).
    """
    # Q . K^T -> (..., seq_len_q, seq_len_k)
    qkt = tf.matmul(q, k, transpose_b=True)
    # Scale by sqrt(d_k), the *key* depth as in the paper. tf.shape works even
    # when the static last dimension is unknown. Casting to the logits' dtype
    # keeps the division valid for non-float32 inputs (e.g. float16).
    dk = tf.math.sqrt(tf.cast(tf.shape(k)[-1], qkt.dtype))
    # Softmax over the key axis, so each query's weights sum to 1.
    weights = tf.nn.softmax(qkt / dk, axis=-1)
    # (..., seq_len_q, depth_v)
    return tf.matmul(weights, v)
class MultiAttention(tf.keras.layers.Layer):
    """Multi-head self-attention built from one set of Dense projections per head.

    Each head projects the input to ``depth = d_model // num_of_heads``
    features for Q, K and V and runs the module-level ``scaled_dot_product``.
    The head outputs are concatenated and projected back to ``d_model`` by
    ``wo``.

    Args:
        d_model: Size of the output feature dimension.
        num_of_heads: Number of attention heads; must be positive and divide
            ``d_model`` so the heads together cover exactly ``d_model``
            features.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_of_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_of_heads, **kwargs):
        super().__init__(**kwargs)
        if num_of_heads <= 0 or d_model % num_of_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a multiple of "
                f"num_of_heads ({num_of_heads})"
            )
        self.d_model = d_model
        self.num_of_heads = num_of_heads
        self.depth = d_model // num_of_heads
        # One independent Q/K/V projection per head.
        self.wq = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(d_model)

    def call(self, x):
        """Attend over ``x`` with every head and project the concatenation.

        Args:
            x: Input tensor; the Dense projections act on its last axis.

        Returns:
            Tensor with the same leading axes as ``x`` and ``d_model``
            features on the last axis.
        """
        heads = [
            scaled_dot_product(wq(x), wk(x), wv(x))
            for wq, wk, wv in zip(self.wq, self.wk, self.wv)
        ]
        return self.wo(tf.concat(heads, axis=-1))
# Calling the attention
m = 5
sequence_length = 4
word_embedding_dim = 512
multi = MultiAttention(d_model=word_embedding_dim, num_of_heads=8)
# tf.random.normal already returns a tensor; wrapping it in tf.constant is redundant.
sample_ip = tf.random.normal(shape=(m, sequence_length, word_embedding_dim))
attn = multi(sample_ip)
# The output keeps batch and sequence axes and has d_model features: (5, 4, 512).
assert attn.shape == (m, sequence_length, word_embedding_dim)
在您的实现中,scaled_dot_product 使用 query 的最后一维进行缩放,而根据原论文,缩放因子应取 key 的维度,即除以 $\sqrt{d_k}$。(由于计算 $QK^\top$ 要求 q 与 k 的最后一维相同,两者在数值上其实相等,但按 key 取值更符合论文的定义。)除此之外,这个实现看起来没有问题,只是不够通用。
class MultiAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with per-head Dense projections.

    Each of the ``num_of_heads`` heads projects the input to
    ``depth = out_dim // num_of_heads`` features for Q, K and V and applies
    scaled dot-product attention. The head outputs are concatenated and
    projected to ``out_dim`` by ``wo``.

    Note: ``out_dim`` need not be a multiple of ``num_of_heads``. In that case
    the concatenation has ``num_of_heads * depth`` (< ``out_dim``) features,
    and ``wo`` still maps it to ``out_dim`` (e.g. 3 heads, out_dim=32 gives
    30 concatenated features projected to 32).

    Args:
        num_of_heads: Number of attention heads.
        out_dim: Size of the output feature dimension.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_of_heads, out_dim, **kwargs):
        super().__init__(**kwargs)
        self.out_dim = out_dim
        self.num_of_heads = num_of_heads
        self.depth = self.out_dim // self.num_of_heads
        # One independent Q/K/V projection per head.
        self.wq = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(self.out_dim)

    def call(self, x):
        """Apply every head to ``x`` (features on the last axis) and project to ``out_dim``."""
        heads = [
            self.scaled_dot_product(wq(x), wk(x), wv(x))
            for wq, wk, wv in zip(self.wq, self.wk, self.wv)
        ]
        return self.wo(tf.concat(heads, axis=-1))

    def scaled_dot_product(self, q, k, v):
        """Return softmax(Q K^T / sqrt(d_k)) V, where d_k is the key depth."""
        qkt = tf.matmul(q, k, transpose_b=True)
        # tf.shape works when the static last dim is unknown. Casting to the
        # logits' dtype lets non-float32 inputs (e.g. float16) divide cleanly.
        dk = tf.math.sqrt(tf.cast(tf.shape(k)[-1], qkt.dtype))
        # Softmax over the key axis.
        weights = tf.nn.softmax(qkt / dk, axis=-1)
        return tf.matmul(weights, v)
multi = MultiAttention(num_of_heads=3, out_dim=32)
sample_ip = tf.random.normal(shape=(2, 2, 32))
print(sample_ip.shape)
# A bare expression only displays in a notebook; print it so a script shows it too.
print(multi(sample_ip).shape)
一般的 Transformer 注意力结构可以用下图表示(原图未保留):前两个线性层分别生成 query 和 key,由它们计算出注意力权重图,然后再通过矩阵乘法用这些权重对 value 进行加权。
我知道您想精简原始的 TF 教程代码,但我认为您应该先注明原始代码的出处。在原始实现中,除了加权后的特征图之外,还会一并返回注意力权重(即概率或分数),我认为不应省略这一点。
您参考的原始代码更通用,也做了更高效的优化。
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention as in the TensorFlow Transformer tutorial.

    Q, K and V are each projected to ``d_model`` features with a single Dense
    layer, split into ``num_heads`` heads of ``depth = d_model // num_heads``,
    attended in parallel, merged back, and projected by ``dense``.

    Args:
        d_model: Model (and output) feature size.
        num_heads: Number of heads; must be positive and divide ``d_model``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Raise instead of ``assert``: asserts are stripped under ``python -O``.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a multiple of num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k) + mask * -1e9) V.

        Args:
            q: (..., seq_len_q, depth) queries.
            k: (..., seq_len_k, depth) keys.
            v: (..., seq_len_k, depth_v) values.
            mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k).
                Positions holding 1 get -1e9 added to their logit, which
                effectively zeroes their weight. Any numeric or bool dtype is
                accepted; it is cast to the logits' dtype.

        Returns:
            Tuple ``(output, attention_weights)`` with shapes
            (..., seq_len_q, depth_v) and (..., seq_len_q, seq_len_k).
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

        # Scale by sqrt(d_k), the key depth.
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # Cast so integer/bool masks work; a float32 mask is unchanged.
        if mask is not None:
            scaled_attention_logits += (
                tf.cast(mask, scaled_attention_logits.dtype) * -1e9
            )

        # Softmax over the last axis (seq_len_k), so each query's scores sum to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

        return output, attention_weights

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Transpose the result so the shape is (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Returns:
            Tuple ``(output, attention_weights)``. ``output`` has shape
            (batch_size, seq_len_q, d_model); ``attention_weights`` has shape
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(
            q, k, v, mask
        )

        # Merge the heads back: (batch_size, seq_len_q, num_heads, depth) -> (batch_size, seq_len_q, d_model)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights
仅供参考:TF 2.4 中正式加入了 tf.keras.layers.MultiHeadAttention 层。
# Built-in layer (TF >= 2.4). keras.Input's shape excludes the batch axis,
# so this symbolic input is 4-D: (None, 2, 2, 32).
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
input_tensor = tf.keras.Input(shape=[2, 2, 32])
print(input_tensor.shape)
# Self-attention: the same tensor serves as query and value.
result = layer(input_tensor, input_tensor)
print(result.shape)
您可以按如下方式分别测试这两种实现:
# custom layer MHA
multi = MultiHeadAttention(d_model=512, num_heads=2)
y = tf.random.uniform((1, 60, 512))
out, attn = multi(y, k=y, q=y, mask=None)
print(out.shape, attn.shape)
# Expected (notebook output): (TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
assert out.shape == (1, 60, 512)
assert attn.shape == (1, 2, 60, 60)
# built-in layer
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
y = tf.random.uniform((1, 60, 512))
out, attn = layer(y, y, return_attention_scores=True)
print(out.shape, attn.shape)
# Expected (notebook output): (TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
assert out.shape == (1, 60, 512)
assert attn.shape == (1, 2, 60, 60)