Getting error with PyTorch Lightning when passing model checkpoint

I'm training a multi-label classification problem using a Hugging Face model, and I'm using PyTorch Lightning to train it.

Here is the code:

Trigger early stopping when the loss hasn't improved for the last 2 epochs:

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

Then we can start the training process:

checkpoint_callback = ModelCheckpoint(
  dirpath="checkpoints",
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)


trainer = pl.Trainer(
  logger=logger,
  callbacks=[early_stopping_callback],
  max_epochs=N_EPOCHS,
  checkpoint_callback=checkpoint_callback,
  gpus=1,
  progress_bar_refresh_rate=30
)
# checkpoint_callback=checkpoint_callback,

As soon as I run this, I get the following error:

~/.local/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py in _configure_checkpoint_callbacks(self, checkpoint_callback)
     75             if isinstance(checkpoint_callback, Callback):
     76                 error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
---> 77             raise MisconfigurationException(error_msg)
     78         if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
     79             raise MisconfigurationException(

MisconfigurationException: Invalid type provided for checkpoint_callback: Expected bool but received <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>. Pass callback instances to the `callbacks` argument in the Trainer constructor instead.
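The checks behind this message can be read off the traceback: any non-bool value for checkpoint_callback is rejected, and if that value is a Callback, the hint about the callbacks argument is appended. Below is a minimal sketch of that logic in plain Python. The classes are hypothetical stand-ins, not the real pytorch_lightning classes, and the second error message is a placeholder because its real text is cut off in the traceback.

```python
# Hypothetical stand-ins for the classes named in the traceback.
# These are NOT the real pytorch_lightning classes.
class Callback:
    pass


class ModelCheckpoint(Callback):
    pass


class MisconfigurationException(Exception):
    pass


def configure_checkpoint_callbacks(checkpoint_callback, callbacks):
    """Simplified sketch of the checks visible in the traceback."""
    if not isinstance(checkpoint_callback, bool):
        error_msg = (
            "Invalid type provided for checkpoint_callback: "
            f"Expected bool but received {type(checkpoint_callback)}."
        )
        # Passing a callback object here is the mistake in the question,
        # so the message points to the `callbacks` argument instead.
        if isinstance(checkpoint_callback, Callback):
            error_msg += (
                " Pass callback instances to the `callbacks` argument"
                " in the Trainer constructor instead."
            )
        raise MisconfigurationException(error_msg)

    # Second check from the traceback: a user ModelCheckpoint in `callbacks`
    # conflicts with explicitly disabling checkpointing.
    has_checkpoint = any(isinstance(cb, ModelCheckpoint) for cb in callbacks)
    if has_checkpoint and checkpoint_callback is False:
        # Placeholder text: the real message is truncated in the traceback.
        raise MisconfigurationException(
            "checkpoint_callback=False conflicts with a ModelCheckpoint in callbacks."
        )
```

Calling configure_checkpoint_callbacks(ModelCheckpoint(), []) raises the same kind of error as in the question, while configure_checkpoint_callbacks(True, [ModelCheckpoint()]) passes.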

How can I fix this?

The description of the checkpoint_callback argument can be found on the documentation page for pl.Trainer:

checkpoint_callback (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.
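In other words, the argument is only an on/off switch, and any custom checkpoint configuration belongs in callbacks. The sketch below illustrates the quoted behaviour: a default checkpoint is added only when checkpointing is enabled and no user-defined ModelCheckpoint is present. The classes and the resolve_callbacks helper are hypothetical stand-ins, not Lightning's own code.

```python
# Hypothetical stand-ins, NOT the real pytorch_lightning classes.
class Callback:
    pass


class ModelCheckpoint(Callback):
    def __init__(self, monitor=None):
        self.monitor = monitor


class EarlyStopping(Callback):
    def __init__(self, monitor):
        self.monitor = monitor


def resolve_callbacks(callbacks, checkpoint_callback=True):
    """Return the callback list a trainer would use (simplified sketch)."""
    callbacks = list(callbacks)  # copy so the caller's list is not mutated
    has_user_checkpoint = any(isinstance(cb, ModelCheckpoint) for cb in callbacks)
    if checkpoint_callback and not has_user_checkpoint:
        # Stand-in for the default checkpoint the docs describe.
        callbacks.append(ModelCheckpoint())
    return callbacks
```

With callbacks=[checkpoint_callback, early_stopping_callback], the user's ModelCheckpoint is kept as-is and no default is added.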

You should not pass your custom ModelCheckpoint to this argument. I believe what you want to do is pass both EarlyStopping and ModelCheckpoint in the callbacks list:
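import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)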

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min")

trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30)
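Because both callbacks monitor val_loss with mode "min", you can predict which checkpoint is kept and when training stops from the sequence of validation losses. The function below is a simplified, hypothetical simulation, not Lightning's implementation. It assumes one validation check per epoch and strict improvement (no min_delta).

```python
def simulate(val_losses, patience=2):
    """Replay EarlyStopping(patience) + ModelCheckpoint(save_top_k=1, mode="min").

    Returns (best_epoch, best_loss, last_epoch_run). Simplified sketch only.
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0  # validation checks since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best: with save_top_k=1 the single checkpoint is overwritten.
            best_loss, best_epoch = loss, epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # Early stopping fires; later epochs are never run.
                return best_epoch, best_loss, epoch
    return best_epoch, best_loss, len(val_losses) - 1
```

For example, simulate([0.9, 0.7, 0.75, 0.72, 0.6]) returns (1, 0.7, 3): the checkpoint from epoch 1 is kept and training stops after epoch 3, so the epoch that would have reached 0.6 never runs.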