使用 fit 恢复训练会将批处理步骤重置为 0

Resuming training with fit resets batch step to 0

我最近重写了我的 TensorFlow 2 自定义训练循环以使用 fit,使用 ModelCheckpoint callback to manage the checkpointing I was previously doing manually in the loop. This is all working nicely, but I have one issue which I've been struggling with: resuming training at the correct batch step. I use the TensorBoard 回调和 update_freq=50,当我恢复训练时(从保存的加载权重检查点)我通常会看到如下内容:

以上结果来自 2 运行s,单个 epoch 包含 250 个批处理步骤(只是一个玩具数据集)。顶行是 2 个纪元的第一个 运行,在 500 步之后结束(摘要每 50 步更新一次,但不是在最后一步之后,大概是这样,所以缺少 500 步处的最后一步)。中间的直线只是绘制到第二个 运行 开头的图形线,由底线表示。我 运行 对于 3 个时期,因此有 750 个批处理步骤 (250 * 3)。

问题是每次重新开始训练时,步数都会从 0 重新开始。我怎样才能解决这个问题?据推测,TensorBoard 回调正在从 0 开始跟踪每个纪元中的步骤... fit 方法有一个 initial_epoch 参数,我用它在正确的纪元重新启动,但是是否可以跟踪全局批处理步骤?我在旧的(TF2 之前的)代码中看到 global_step,是用来实现这个的吗?

好吧,似乎管理这个的方法是通过 runs 的概念,这相当于为不同的培训课程创建单独的日志子目录 - 如 [=10= 中所示].