BERT outputs explained
The keys of the BERT encoder output are default, encoder_outputs, pooled_output, and sequence_output. As far as I know, encoder_outputs are the outputs of each encoder block, pooled_output is the output for the global context, and sequence_output is the contextual output for each token (correct me if I'm wrong). But what is default? Can you give me a more detailed explanation?
The TensorFlow docs give a good explanation of the outputs you are asking about:
The BERT models return a map with 3 important keys: pooled_output, sequence_output, encoder_outputs:

pooled_output represents each input sequence as a whole. The shape is [batch_size, H]. You can think of this as an embedding for the entire movie review.

sequence_output represents each input token in the context. The shape is [batch_size, seq_length, H]. You can think of this as a contextual embedding for every token in the movie review.

encoder_outputs are the intermediate activations of the L Transformer blocks. outputs["encoder_outputs"][i] is a Tensor of shape [batch_size, seq_length, 1024] with the outputs of the i-th Transformer block, for 0 <= i < L. The last value of the list is equal to sequence_output.
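To make those shapes concrete, here is a minimal sketch using the same small BERT encoder as in the code further down (H = 512, L = 4; the preprocessing model pads to seq_length = 128, and the 1024 in the quoted docs is the H of a larger model):

import tensorflow as tf
import tensorflow_hub as hub

preprocess = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
encoder = hub.KerasLayer('https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1')

outputs = encoder(preprocess(tf.constant(["A single example sentence"])))

print(outputs['pooled_output'].shape)        # (1, 512)      -> [batch_size, H]
print(outputs['sequence_output'].shape)      # (1, 128, 512) -> [batch_size, seq_length, H]
print(len(outputs['encoder_outputs']))       # 4             -> one entry per Transformer block (L)
print(outputs['encoder_outputs'][-1].shape)  # (1, 128, 512) -> same shape as sequence_output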
Here is another interesting discussion on the difference between pooled_output and sequence_output, if you are interested.
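Note that pooled_output is not just the [CLS] (first) token sliced out of sequence_output: BERT passes that token through an additional learned dense layer with a tanh activation. Continuing the sketch above, you can check that the values differ:

# pooled_output is a learned transform of the [CLS] token, not a plain slice.
cls_embedding = outputs['sequence_output'][:, 0, :]  # raw [CLS] token, shape (1, 512)
print(tf.reduce_any(cls_embedding != outputs['pooled_output']).numpy())  # True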
The default output is equal to pooled_output, as you can confirm here:
import tensorflow as tf
import tensorflow_hub as hub

tfhub_handle_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
tfhub_handle_encoder = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1'

def build_classifier_model(name):
    # Map raw text through the preprocessor and encoder,
    # exposing the requested encoder output as the model output.
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='features')
    bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = bert_preprocess_model(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder)
    outputs = encoder(encoder_inputs)
    net = outputs[name]
    return tf.keras.Model(text_input, net)

sentence = tf.constant([
    "Improve the physical fitness of your goldfish by getting him a bicycle"
])

classifier_model = build_classifier_model(name='default')
default_output = classifier_model(sentence)

classifier_model = build_classifier_model(name='pooled_output')
pooled_output = classifier_model(sentence)

print(default_output == pooled_output)  # elementwise True everywhere
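The same trick verifies the docs' claim that the last entry of encoder_outputs equals sequence_output:

classifier_model = build_classifier_model(name='sequence_output')
sequence_output = classifier_model(sentence)

classifier_model = build_classifier_model(name='encoder_outputs')
encoder_outputs = classifier_model(sentence)  # a list with one tensor per Transformer block

# The last intermediate activation is the final sequence output.
print(tf.reduce_all(encoder_outputs[-1] == sequence_output))  # tf.Tensor(True, shape=(), dtype=bool)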