tf.keras.callbacks.ModelCheckpoint 对比 tf.train.Checkpoint

tf.keras.callbacks.ModelCheckpoint vs tf.train.Checkpoint

我对 TensorFlow 世界有点陌生,但已经用 Keras 编写了一些程序。由于 TensorFlow 2 官方类似于 Keras,所以我很困惑 tf.keras.callbacks.ModelCheckpoint 和 tf.train.Checkpoint 之间有什么区别。如果有人能阐明这一点,我将不胜感激。

TensorFlow 是一个 'computation' 库,而 Keras 是一个深度学习库,可以与 TF 或 PyTorch 等一起使用。所以 TF 提供的是一个更通用的非定制深度学习库版本。如果你只是比较文档,你会发现 ModelCheckpoint 是多么的全面和定制化。 Checkpoint 只是读写 from/to 磁盘。 ModelCheckpoint 更聪明!

此外,ModelCheckpoint是回调。这意味着您可以创建它的实例并将其传递给 fit 函数:

model_checkpoint = ModelCheckpoint(...)
model.fit(..., callbacks=[..., model_checkpoint, ...], ...)

我快速浏览了 Keras 的 ModelCheckpoint 实现,它在 Model 上调用 savesave_weights 方法,在某些情况下使用 TensorFlow 的 CheckPoint 本身。所以它本身不是包装器,但肯定处于较低的抽象级别——更专门用于保存 Keras 模型。

这取决于是否需要自定义训练循环。在大多数情况下,它不是,您可以只调用 model.fit() 并传递 tf.keras.callbacks.ModelCheckpoint。如果您确实需要编写自定义训练循环,则必须使用 tf.train.Checkpoint(和 tf.train.CheckpointManager),因为没有回调机制。

当我看别人的代码时,我也很难区分使用的检查点对象,所以我写下了一些关于何时使用哪个以及一般如何使用它们的注释。 无论哪种方式,我认为它可能对遇到相同问题的其他人有用:

保存模型检查点

这些是保存模型检查点的 2 种方法,每种方法用于不同的用例:

1) 检查点和检查点管理器

当您自己管理训练循环时,这非常有用。

你这样使用它们:

1.1) 检查点

来自 docs 的定义: “可以构造检查点对象以将单个或一组可跟踪对象保存到检查点文件”。

如何初始化:

  • 您可以将键值对传递给它:
    • 构成您的模型并且您想跟踪的所有自定义函数调用或对象:
    • 像生成器、判别器、损失函数、优化器等
ckpt = Checkpoint(discr_opt=discr_opt,
                  genrt_opt=genrt_opt,
                  wgan = wgan,
                  d_model = d_model,
                  g_model = g_model)
1.2) 检查点管理器

这从字面上管理您定义的要存储在某个位置的检查点以及要保留的数量。 来自 docs 的定义: "通过保留一些并删除不需要的检查点来管理多个检查点"

如何初始化:

  • 使用您创建的 CheckPoint 对象作为第一个参数对其进行初始化。
  • 保存检查点文件的目录。
  • 您可能想定义要保留多少,因为这可能是很多复杂的模型
manager = CheckpointManager(ckpt, "training_checkpoints_wgan", max_to_keep=3)

使用方法:

  • 我们已经使用指定的检查点设置了管理器对象,因此可以使用了。
  • 在每个训练时期结束时调用它
manager.save()

2) 模型检查点(回调)

您想在自己不管理纪元迭代时使用此 回调。例如,当您设置了一个相对简单的顺序模型并调用 model.fit() 时,它会为您管理训练过程。

来自 docs 的定义: “以某种频率保存 Keras 模型或模型权重的回调。

如何初始化:

  • 传入模型保存路径

  • 选项save_weights_only默认设置为False:

    • 如果您只想保存权重,请务必更新此项
  • 选项save_best_only默认设置为False:

    • 如果你只想保存最好的模型而不是所有模型,你可以将它设置为 True。
  • verbose 设置为 0 (False),因此您可以将其更新为 1 以验证它

mc = ModelCheckpoint("training_checkpoints/cp.ckpt", save_best_only=True, save_weights_only=False)

使用方法:

  • 模型检查点回调现在已准备好进行训练。
  • 当你拟合模型时,你将你体内的对象传递到你的回调列表中:
model.fit(X, y, epochs=100, callbacks=[mc])