OOM error when training the BERT Keras model
I am fine-tuning a BERT model with Keras. However, as soon as training starts on the GPU I hit an OOM error. Below is the code for my model.
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

max_len = 256

# The three standard BERT inputs: token ids, attention mask and segment ids
input_word_ids = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="segment_ids")

# bert_layer is the pre-trained BERT layer loaded earlier (e.g. from TF Hub)
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])

# Classification head on top of the sequence output
x = Dense(128, activation="relu")(sequence_output)
x = Dropout(0.1)(x)
out = Dense(1, activation="sigmoid")(x)

model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=out)
model.compile(Adam(lr=2e-6), loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

train_history = model.fit(train_input, train_labels,
                          validation_split=0.2,
                          epochs=3,
                          batch_size=16)
The error I get is:
---------------------------------------------------------------------------
ResourceExhaustedError Traceback (most recent call last)
<timed exec> in <module>
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1098 _r=1):
1099 callbacks.on_train_batch_begin(step)
-> 1100 tmp_logs = self.train_function(iterator)
1101 if data_handler.should_sync:
1102 context.async_wait()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
826 tracing_count = self.experimental_get_tracing_count()
827 with trace.Trace(self._name) as tm:
--> 828 result = self._call(*args, **kwds)
829 compiler = "xla" if self._experimental_compile else "nonXla"
830 new_tracing_count = self.experimental_get_tracing_count()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
886 # Lifting succeeded, so variables are initialized and we can run the
887 # stateless function.
--> 888 return self._stateless_fn(*args, **kwds)
889 else:
890 _, _, _, filtered_flat_args = \
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
2941 filtered_flat_args) = self._maybe_define_function(args, kwargs)
2942 return graph_function._call_flat(
-> 2943 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
2944
2945 @property
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1917 # No tape is watching; skip to running the function.
1918 return self._build_call_outputs(self._inference_function.call(
-> 1919 ctx, args, cancellation_manager=cancellation_manager))
1920 forward_backward = self._select_forward_and_backward_functions(
1921 args,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
558 inputs=args,
559 attrs=attrs,
--> 560 ctx=ctx)
561 else:
562 outputs = execute.execute_with_cancellation(
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
ResourceExhaustedError: OOM when allocating tensor with shape[16,160,1024] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
[[{{node model_1/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/bert_model/StatefulPartitionedCall/encoder/layer_23/output_layer_norm/moments/SquaredDifference}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
[Op:__inference_train_function_105694]
Function call stack:
train_function
Also, when I run the code with only the last dense layer, I don't seem to get the OOM error. It only happens once I add the Dense(units = 128) and Dropout(0.1) layers.
Can anyone help me with this? Thanks in advance.
A quick Google search led me to this discussion, where it is pointed out that letting the validation_split argument split the data is what causes this OOM error. The fix is to split the data yourself, e.g. with sklearn.model_selection.train_test_split() (or any other way you prefer), before calling model.fit().
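A minimal sketch of that workaround, assuming train_input is a list of three NumPy arrays (word ids, input mask, segment ids) and train_labels is a 1-D label array; the variable names are just illustrative. The data is split up front and handed to model.fit via validation_data, so validation_split is no longer needed:

from sklearn.model_selection import train_test_split

# Split all three BERT inputs and the labels in one call so the rows stay aligned
ids_tr, ids_val, mask_tr, mask_val, seg_tr, seg_val, y_tr, y_val = train_test_split(
    train_input[0], train_input[1], train_input[2], train_labels,
    test_size=0.2, random_state=42)

train_history = model.fit(
    [ids_tr, mask_tr, seg_tr], y_tr,
    validation_data=([ids_val, mask_val, seg_val], y_val),
    epochs=3,
    batch_size=16)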