TensorFlow Checkpoint variables not saved

I am trying to use Checkpoint for my model. I first tried it on a toy example, which runs without errors. But every time I run it, the trained parameters seem to start over from their initial values. Am I missing something here? Below is the code I am using:

import numpy as np
import tensorflow as tf

X = tf.range(10.)
Y = 50.*X
    
class CGMM(object):
    def __init__(self):
        self.beta = tf.Variable(1., dtype=np.float32)

    @tf.function
    def objfun(self):
        beta = self.beta
        obj = tf.reduce_mean(tf.square(beta*self.X - self.Y))
        return obj

    def build_model(self, X, Y):
        self.X, self.Y = X, Y
        optimizer = tf.keras.optimizers.RMSprop(0.5)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.objfun, optimizer=optimizer)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts_cg', max_to_keep=3)

        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        for i in range(20):
            optimizer.minimize(self.objfun, var_list=self.beta)
            loss, beta = self.objfun(), self.beta
            # print(self.beta.numpy())
            ckpt.step.assign_add(1)
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
                print("loss {:1.2f}".format(loss.numpy()))
                print("beta {:1.2f}".format(beta.numpy()))

        return beta


model = CGMM()
opt_beta = model.build_model(X, Y)

Result of the 1st run:

Initializing from scratch.
Saved checkpoint for step 5: ./tf_ckpts_cg/ckpt-1
loss 56509.74
beta 5.47
Saved checkpoint for step 10: ./tf_ckpts_cg/ckpt-2
loss 48354.54
beta 8.81
Saved checkpoint for step 15: ./tf_ckpts_cg/ckpt-3
loss 42085.54
beta 11.57
Saved checkpoint for step 20: ./tf_ckpts_cg/ckpt-4
loss 36750.57
beta 14.09

Result of the 2nd run:

Restored from ./tf_ckpts_cg/ckpt-4
Saved checkpoint for step 25: ./tf_ckpts_cg/ckpt-5
loss 54619.16
beta 6.22
Saved checkpoint for step 30: ./tf_ckpts_cg/ckpt-6
loss 46997.79
beta 9.39
Saved checkpoint for step 35: ./tf_ckpts_cg/ckpt-7
loss 40958.30
beta 12.09
Saved checkpoint for step 40: ./tf_ckpts_cg/ckpt-8
loss 35763.21
beta 14.58

What the problem is:

The main problem is that your beta variable is not trackable, which means the checkpoint object does not save it. We can see this by inspecting the contents of the checkpoint with the following function:

>>> tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts_cg/'))

[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/momentum/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/rho/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

The only tf.Variables tracked by the checkpoint are the optimizer's and the ones used by the tf.train.Checkpoint object itself.


Possible solution:

To change that, you need to make your variable tracked. The TensorFlow documentation on the subject is not great, but after some digging you can find the following in the tf.Variable documentation:

Variables are automatically tracked when assigned to attributes of types inheriting from tf.Module.

[...]

This tracking then allows saving variable values to training checkpoints, or to SavedModels which include serialized TensorFlow graphs.

So, by making your CGMM class inherit from tf.Module, your beta variable becomes tracked and can be restored! It is a very straightforward change to your code:

class CGMM(tf.Module):
    def __init__(self):
        super(CGMM, self).__init__(name='CGMM')
        self.beta = tf.Variable(1., dtype=np.float32)

We also need to tell the Checkpoint object that the model is now the CGMM object:

ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self, optimizer=optimizer)
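Putting both changes together, the complete script might look like this (a sketch: the hyperparameters are unchanged, and the explicit GradientTape update is equivalent to the optimizer.minimize call in the question):

```python
import numpy as np
import tensorflow as tf

X = tf.range(10.)
Y = 50. * X

class CGMM(tf.Module):
    # Inheriting from tf.Module makes attribute-assigned variables trackable.
    def __init__(self):
        super(CGMM, self).__init__(name='CGMM')
        self.beta = tf.Variable(1., dtype=np.float32)

    @tf.function
    def objfun(self):
        return tf.reduce_mean(tf.square(self.beta * self.X - self.Y))

    def build_model(self, X, Y):
        self.X, self.Y = X, Y
        optimizer = tf.keras.optimizers.RMSprop(0.5)
        # model=self (the tf.Module) instead of model=self.objfun
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self, optimizer=optimizer)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts_cg', max_to_keep=3)

        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        for _ in range(20):
            # One gradient step on beta (same effect as optimizer.minimize).
            with tf.GradientTape() as tape:
                loss = self.objfun()
            grads = tape.gradient(loss, [self.beta])
            optimizer.apply_gradients(zip(grads, [self.beta]))

            ckpt.step.assign_add(1)
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
                print("loss {:1.2f}".format(loss.numpy()))
                print("beta {:1.2f}".format(self.beta.numpy()))
        return self.beta

model = CGMM()
opt_beta = model.build_model(X, Y)
```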

Now, if we train for a few steps and look at the contents of the checkpoint file, we get something promising: the beta variable is now saved:

>>> tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts_cg/'))

[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('model/beta/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('model/beta/.OPTIMIZER_SLOT/optimizer/rms/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/momentum/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/rho/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]
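As an extra sanity check (a self-contained toy, not part of the original answer: the `Toy` class and the throwaway directory are illustrative), the stored value can be read straight out of a checkpoint file with tf.train.load_variable:

```python
import os
import tempfile

import tensorflow as tf

class Toy(tf.Module):
    # Same pattern as above: tf.Module makes self.beta trackable.
    def __init__(self):
        super(Toy, self).__init__(name='Toy')
        self.beta = tf.Variable(42., dtype=tf.float32)

# Save a checkpoint into a throwaway directory.
ckpt = tf.train.Checkpoint(model=Toy())
path = ckpt.save(os.path.join(tempfile.mkdtemp(), 'ckpt'))

# Read the raw stored value back without rebuilding any objects.
value = tf.train.load_variable(path, 'model/beta/.ATTRIBUTES/VARIABLE_VALUE')
print(value)  # 42.0
```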

If we run the program a few times, we get:

>>> run tf-ckpt.py

Restored from ./tf_ckpts_cg/ckpt-28
Saved checkpoint for step 145: ./tf_ckpts_cg/ckpt-29
loss 0.00
beta 49.99
Saved checkpoint for step 150: ./tf_ckpts_cg/ckpt-30
loss 0.00
beta 50.00

Hooray!


Note: to get variables tracked, you can also use any keras.layers.Layer as well as any keras.Model. This is probably the easiest way.

From the Training Checkpoint guide:

Subclasses of tf.train.Checkpoint, tf.keras.layers.Layer, and tf.keras.Model automatically track variables assigned to their attributes.
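As a sketch of that alternative (the `CGMMLayer` name and the throwaway directory are illustrative, not from the original post), the same tracking works when the variable is assigned to a tf.keras.layers.Layer subclass:

```python
import os
import tempfile

import tensorflow as tf

# A Layer subclass also tracks attribute-assigned variables automatically.
class CGMMLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(CGMMLayer, self).__init__(name='CGMM')
        self.beta = tf.Variable(1., dtype=tf.float32, name='beta')

layer = CGMMLayer()
ckpt = tf.train.Checkpoint(model=layer)
path = ckpt.save(os.path.join(tempfile.mkdtemp(), 'ckpt'))

# beta now appears among the saved variables:
print([name for name, _ in tf.train.list_variables(path)])
```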