Stacking an LSTM layer on top of a BERT encoder in Keras

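I have been trying to stack an LSTM layer on top of BERT embeddings, but when my model starts training, it fails on the last batch and throws the following error message: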

    Node: 'model/tf.reshape/Reshape'
    Input to reshape is a tensor with 59136 values, but the requested shape has 98304
         [[{{node model/tf.reshape/Reshape}}]] [Op:__inference_train_function_70500]

This is how I build the model, and I honestly can't figure out what is going wrong here:

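import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # registers the custom ops used by the BERT preprocessing model

batch_size = 128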


bert_preprocess = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
bert_encoder = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4', trainable=True)

text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text)  # dict of encoder outputs
bert_output = outputs['pooled_output']  # shape=(None, 768)

l = tf.reshape(bert_output, [batch_size, 1, 768])

l = tf.keras.layers.LSTM(32, activation='relu')(l)


l = tf.keras.layers.Dropout(0.1, name='dropout')(l)
l = tf.keras.layers.Dense(8, activation='softmax', name="output")(l)
model = tf.keras.Model(inputs=[text_input], outputs = [l])

print(model.summary())
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1, batch_size = batch_size)

Here is the full output:

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 text (InputLayer)              [(None,)]            0           []                               
                                                                                                  
 keras_layer (KerasLayer)       {'input_mask': (Non  0           ['text[0][0]']                   
                                e, 128),                                                          
                                 'input_word_ids':                                                
                                (None, 128),                                                      
                                 'input_type_ids':                                                
                                (None, 128)}                                                      
                                                                                                  
 keras_layer_1 (KerasLayer)     {'sequence_output':  109482241   ['keras_layer[0][0]',            
                                 (None, 128, 768),                'keras_layer[0][1]',            
                                 'default': (None,                'keras_layer[0][2]']            
                                768),                                                             
                                 'encoder_outputs':                                               
                                 [(None, 128, 768),                                               
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768)],                                               
                                 'pooled_output': (                                               
                                None, 768)}                                                       
                                                                                                  
 tf.reshape (TFOpLambda)        (128, 1, 768)        0           ['keras_layer_1[0][13]']         
                                                                                                  
 lstm (LSTM)                    (128, 32)            102528      ['tf.reshape[0][0]']             
                                                                                                  
 dropout (Dropout)              (128, 32)            0           ['lstm[0][0]']                   
                                                                                                  
 output (Dense)                 (128, 8)             264         ['dropout[0][0]']                
                                                                                                  
==================================================================================================
Total params: 109,585,033
Trainable params: 102,792
Non-trainable params: 109,482,241
__________________________________________________________________________________________________
None
    WARNING:tensorflow:AutoGraph could not transform <function Model.make_train_function.<locals>.train_function at 0x7fc4ff809440> and will run it as-is.
    Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
    Cause: 'arguments' object has no attribute 'posonlyargs'
    To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
    WARNING:tensorflow:AutoGraph could not transform <function Model.make_train_function.<locals>.train_function at 0x7fc4ff809440> and will run it as-is.
    Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
    Cause: 'arguments' object has no attribute 'posonlyargs'
    To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
    18/19 [===========================>..] - ETA: 25s - loss: 1.5747 - accuracy: 0.5456Traceback (most recent call last):
      File "bert-test-lstm.py", line 62, in <module>
        model.fit(x_train, y_train, epochs=1, batch_size = batch_size)
      File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None
      File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 55, in quick_execute
        inputs, attrs, num_outputs)
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
Detected at node 'model/tf.reshape/Reshape' defined at (most recent call last):
    File "bert-test-lstm.py", line 62, in <module>
      model.fit(x_train, y_train, epochs=1, batch_size = batch_size)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/functional.py", line 452, in call
      inputs, training=training, mask=mask)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/layers/core/tf_op_layer.py", line 226, in _call_wrapper
      return self._call_wrapper(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/layers/core/tf_op_layer.py", line 261, in _call_wrapper
      result = self.function(*args, **kwargs)
Node: 'model/tf.reshape/Reshape'
Input to reshape is a tensor with 59136 values, but the requested shape has 98304

If I just remove the LSTM and reshape layers, the code runs perfectly well. Any help would be greatly appreciated.

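You should use tf.keras.layers.Reshape to reshape bert_output into a 3D tensor; it takes the target shape without the batch axis and handles the batch dimension automatically. The hard-coded batch_size in tf.reshape only works for full batches of 128. The last batch holds 59136 / 768 = 77 samples, which cannot fill the requested 128 × 1 × 768 = 98304 values.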
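The numbers in the error message can be checked by hand. The snippet below is only a back-of-the-envelope sketch: the constant names are made up for illustration, and the total sample count assumes that the 18 batches shown as completed in the progress bar were all full.

```python
HIDDEN = 768        # width of BERT's pooled_output
BATCH_SIZE = 128    # batch_size passed to model.fit in the question

# Number of values demanded by the hard-coded shape [batch_size, 1, 768].
requested = BATCH_SIZE * 1 * HIDDEN

# Number of values the reshape op actually received, per the error message.
received = 59136

# How many rows of width 768 that corresponds to.
rows_in_last_batch, remainder = divmod(received, HIDDEN)

# Assumption: the 18 batches before the failure were full batches.
total_samples = 18 * BATCH_SIZE + rows_in_last_batch

print(requested)            # 98304
print(rows_in_last_batch)   # 77
print(remainder)            # 0
print(total_samples)        # 2381
```

So the failing step received a final, partial batch of 77 samples, and 77 rows cannot be reshaped into 128.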

Simply change:

l = tf.reshape(bert_output, [batch_size, 1, 768])

to:

l = tf.keras.layers.Reshape((1,768))(bert_output)

and it should work.
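To see the difference without downloading BERT, here is a minimal sketch that uses a zero tensor as a stand-in for pooled_output on a partial batch of 77 rows. The names last_batch, BATCH_SIZE and HIDDEN are hypothetical, and the snippet assumes eager execution (the TensorFlow 2 default). It also tries two alternatives that should give the same result, tf.expand_dims and tf.reshape with -1 for the batch axis.

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 128   # batch size from the question
HIDDEN = 768       # width of BERT's pooled_output

# Stand-in for pooled_output on the final, smaller batch (77 rows).
last_batch = tf.constant(np.zeros((77, HIDDEN), dtype=np.float32))

# Hard-coding the batch size fails whenever a batch does not have exactly 128 rows.
try:
    tf.reshape(last_batch, [BATCH_SIZE, 1, HIDDEN])
    hardcoded_ok = True
except (tf.errors.InvalidArgumentError, ValueError):
    hardcoded_ok = False

# Keras Reshape takes the target shape without the batch axis,
# so it works for any number of rows.
reshaped = tf.keras.layers.Reshape((1, HIDDEN))(last_batch)

# Alternatives: insert a length-1 time axis, or let tf.reshape infer the batch axis.
expanded = tf.expand_dims(last_batch, axis=1)
inferred = tf.reshape(last_batch, [-1, 1, HIDDEN])

print(hardcoded_ok)
print(reshaped.shape, expanded.shape, inferred.shape)
```

Note that any of these variants feeds the LSTM a sequence of length 1, exactly as in the original model; only the handling of the batch dimension changes.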