我如何使用 Tensorflow.Checkpoint 来恢复以前训练过的网络

How can I use Tensorflow.Checkpoint to recover a previously trained net

我正在尝试了解如何使用 tensorflow.train.Checkpoint.restore 恢复 saved/checkpointed 网络。

我正在使用强烈基于 Google 的 Colab 教程创建 pix2pix GAN 的代码。下面,我摘录了关键部分,它只是尝试实例化一个新网络,然后用保存和检查点的先前网络的权重填充它。

我通过对网络的所有权重求和来为网络的特定实例分配一个唯一的(大概)ID 号。我在创建网络时和尝试恢复检查点网络后都比较了这些 ID 号

def main(opt):

    # Initialize pix2pix GAN using arguments input from command line
    p2p = Pix2Pix(vars(opt))
    print(opt)

    # print sum of initial weights for net
    print("Init Model Weights:", 
           sum([x.numpy().sum() for x in p2p.generator.weights]))

    # Create or read from model checkpoints
    checkpoint = tf.train.Checkpoint(generator_optimizer=p2p.generator_optimizer,
                                     discriminator_optimizer=p2p.discriminator_optimizer,
                                     generator=p2p.generator,
                                     discriminator=p2p.discriminator)
    
    # print sum of weights from checkpoint, to ensure it has access 
    # to relevant regions of p2p
    print("Checkpoint Weights:", 
           sum([x.numpy().sum() for x in checkpoint.generator.weights]))

    # Recover Checkpointed net
    checkpoint.restore(tf.train.latest_checkpoint(opt.weights)).expect_partial()

    # print sum of weights for p2p & checkpoint after attempting to restore saved net 
    print("Restore Model Weights:", 
           sum([x.numpy().sum() for x in p2p.generator.weights]))
    print("Restored Checkpoint Weights:", 
           sum([x.numpy().sum() for x in checkpoint.generator.weights]))
    print("Done.")

if __name__ == '__main__':
    opt = parse_opt()
    main(opt)

我运行这段代码得到的输出如下:

Namespace(channels='1', data='data', img_size=256, output='output', weights='weights/ckpt-40.data-00000-of-00001')
## These are the input arguments, the images have only 1 channel (they're gray scale)
## The directory with data is ./data, the images are 265x256
## The output directory is ./output
## The checkpointed net is stored in ./weights/ckpt-40.data-00000-of-00001


## Sums of nets' weights
Init Model Weights: 11047.206374436617
Checkpoint Weights: 11047.206374436617
Restore Model Weights: 11047.206374436617
Restored Checkpoint Weights: 11047.206374436617

Done.

虽然 p2pcheckpoint 似乎可以访问内存中的相同位置,但在恢复检查点版本之前和之后网络的权重总和没有变化。

为什么我没有恢复保存的网络?

我的替代方法是使用回调和恢复,您可以为他们确定的检查点命名图层。

示例:

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: DataSet
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
DATA = adding_array_DATA(DATA, action, reward, gamescores, step)

dataset = tf.data.Dataset.from_tensor_slices((tf.constant(DATA, dtype=tf.float32),tf.constant(np.reshape(0, (1, 1, 1, 1)))))
batched_features = dataset

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Initialize
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(1200, 1)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True, return_state=False)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)),
])
        
model.add(layers.Flatten())
model.add(layers.Dense(64))
model.add(layers.Dense(2))
model.summary()

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Callback
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir, monitor='val_loss', 
                                verbose=0, save_best_only=True, mode='min' )
                                
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Optimizer
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
optimizer = tf.keras.optimizers.Nadam(
    learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-07,
    name='Nadam'
) # 0.00001

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Loss Fn
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""                               
# 1
lossfn = tf.keras.losses.MeanSquaredLogarithmicError(reduction=tf.keras.losses.Reduction.AUTO, name='mean_squared_logarithmic_error')
# 2
# lossfn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Summary
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model.compile(optimizer=optimizer, loss=lossfn, metrics=['accuracy'])

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Training
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
history = model.fit(batched_features, epochs=1 ,validation_data=(batched_features), callbacks=[cp_callback]) # epochs=500 # , callbacks=[cp_callback, tb_callback]

checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(checkpoint_dir)

input('...')

输出:

2022-03-08 10:33:06.965274: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
1/1 [==============================] - ETA: 0s - **loss: 0.0154** - accuracy: 0.0000e+002022-03-08 10:33:16.175845: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
1/1 [==============================] - 31s 31s/step - **loss: 0.0154** - accuracy: 0.0000e+00 - val_loss: 0.0074 - val_accuracy: 0.0000e+00
...

出现问题是因为tf.Checkpoint.restore需要检查点网络的存储目录,而不是特定文件(或者,我认为是特定文件-./weights/ckpt-40.data- 00000-of-00001)

当没有给它一个有效的目录时,它会默默地继续下一行代码,不会更新网络或抛出错误。解决方法是为它提供包含相关检查点文件的目录,而不仅仅是我认为相关的文件。