Keras Transformer官方例子中的解释attention

Interpreting attention in Keras Transformer official example

我已经实现了一个模型,如(使用 Transformer 进行文本分类)https://keras.io/examples/nlp/text_classification_with_transformer/

我想访问特定示例的注意力值。

我了解注意力是围绕这一点计算的:

class TransformerBlock(layers.Layer):
    [...]

def call(self, inputs, training):
    attn_output = self.att(inputs)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(inputs + attn_output)
    ffn_output = self.ffn(out1)
    ffn_output = self.dropout2(ffn_output, training=training)
    return self.layernorm2(out1 + ffn_output)

[...]

embed_dim = 32  # Embedding size for each token

num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in feed forward network inside transformer

inputs = layers.Input(shape=(maxlen,))
embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
x = embedding_layer(inputs)
transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
x = transformer_block(x)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.1)(x)
x = layers.Dense(20, activation="relu")(x)
x = layers.Dropout(0.1)(x)
outputs = layers.Dense(2, activation="softmax")(x)

如果我这样做:

A=(model.layers[2].att(model.layers[1](model.layers[0]((X_train[0,:])))))

我可以检索大小为 maxlen xnum_heads 的矩阵。

我应该如何解释这些系数?

编辑:如果您想使用注意力来解释分类输出

据我所知,无法完全解释 Transformer 在分类方面的作用。 Transformer 所做的只是查看每个输入如何相互关联,而不是查看每个单词如何对标签做出贡献。如果您希望找到可解释的模型,请尝试查看 LSTM attention for classification。

好的,我已经阅读了您的代码并在您调用 model.layers[1] 时发现了一些错误。首先,您需要了解模型正在批量处理数据。因此,您的输入格式应为(batch_size, seq_len)。但是,您的输入形状似乎降低了第一个维度(即批处理),这使您的模型认为您正在给模型 200 个序列长度为 1 的句子。因此,如图所示,输出形状看起来很奇怪。

正确的方法是在第一个维度上增加一个维度(使用tf.expand_dims)。

现在,解释结果。您需要知道 Transformer 块执行 self-attention(找到句子中每个单词对其他单词的分数)并对其进行加权求和。因此,输出将与嵌入层相同,您将无法解释它(因为它是由网络生成的隐藏向量)。

但是,您可以使用以下代码查看每个头部的注意力分数:

import seaborn as sns
import matplotlib.pyplot as plt

head_num=1
inp = tf.expand_dims(x_train[0,:], axis=0)
emb = model.layers[1](model.layers[0]((inp)))

self_attn = model.layers[2].att
# compute Q, K, V
query = self_attn.query_dense(emb)
key = self_attn.key_dense(emb)
value = self_attn.value_dense(emb)
# separate heads
query = self_attn.separate_heads(query, 1) # batch_size = 1
key = self_attn.separate_heads(key, 1) # batch_size = 1
value = self_attn.separate_heads(value, 1) # batch_size = 1
# compute attention scores (QK^T)
attention, weights = self_attn.attention(query, key, value)

idx_word = {v: k for k, v in keras.datasets.imdb.get_word_index().items()}
plt.figure(figsize=(30, 30))
sns.heatmap(
    weights.numpy()[0][head_num], 
    xticklabels=[idx_word[idx] for idx in inp[0].numpy()],
    yticklabels=[idx_word[idx] for idx in inp[0].numpy()]
)

这是一个示例输出: