我如何使用 Tensorflow.Checkpoint 来恢复以前训练过的网络
How can I use Tensorflow.Checkpoint to recover a previously trained net
我正在尝试了解如何使用 tensorflow.train.Checkpoint.restore
恢复 saved/checkpointed 网络。
我正在使用强烈基于 Google 的 Colab 教程创建 pix2pix GAN 的代码。下面,我摘录了关键部分,它只是尝试实例化一个新网络,然后用保存和检查点的先前网络的权重填充它。
我通过对网络的所有权重求和来为网络的特定实例分配一个唯一的(大概)ID 号。我在创建网络时和尝试恢复检查点网络后都比较了这些 ID 号
def main(opt):
# Initialize pix2pix GAN using arguments input from command line
p2p = Pix2Pix(vars(opt))
print(opt)
# print sum of initial weights for net
print("Init Model Weights:",
sum([x.numpy().sum() for x in p2p.generator.weights]))
# Create or read from model checkpoints
checkpoint = tf.train.Checkpoint(generator_optimizer=p2p.generator_optimizer,
discriminator_optimizer=p2p.discriminator_optimizer,
generator=p2p.generator,
discriminator=p2p.discriminator)
# print sum of weights from checkpoint, to ensure it has access
# to relevant regions of p2p
print("Checkpoint Weights:",
sum([x.numpy().sum() for x in checkpoint.generator.weights]))
# Recover Checkpointed net
checkpoint.restore(tf.train.latest_checkpoint(opt.weights)).expect_partial()
# print sum of weights for p2p & checkpoint after attempting to restore saved net
print("Restore Model Weights:",
sum([x.numpy().sum() for x in p2p.generator.weights]))
print("Restored Checkpoint Weights:",
sum([x.numpy().sum() for x in checkpoint.generator.weights]))
print("Done.")
if __name__ == '__main__':
opt = parse_opt()
main(opt)
我运行这段代码得到的输出如下:
Namespace(channels='1', data='data', img_size=256, output='output', weights='weights/ckpt-40.data-00000-of-00001')
## These are the input arguments, the images have only 1 channel (they're gray scale)
## The directory with data is ./data, the images are 265x256
## The output directory is ./output
## The checkpointed net is stored in ./weights/ckpt-40.data-00000-of-00001
## Sums of nets' weights
Init Model Weights: 11047.206374436617
Checkpoint Weights: 11047.206374436617
Restore Model Weights: 11047.206374436617
Restored Checkpoint Weights: 11047.206374436617
Done.
虽然 p2p
和 checkpoint
似乎可以访问内存中的相同位置,但在恢复检查点版本之前和之后网络的权重总和没有变化。
为什么我没有恢复保存的网络?
我的替代方法是使用回调和恢复,您可以为他们确定的检查点命名图层。
示例:
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: DataSet
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
DATA = adding_array_DATA(DATA, action, reward, gamescores, step)
dataset = tf.data.Dataset.from_tensor_slices((tf.constant(DATA, dtype=tf.float32),tf.constant(np.reshape(0, (1, 1, 1, 1)))))
batched_features = dataset
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Initialize
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(1200, 1)),
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True, return_state=False)),
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)),
])
model.add(layers.Flatten())
model.add(layers.Dense(64))
model.add(layers.Dense(2))
model.summary()
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Callback
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir, monitor='val_loss',
verbose=0, save_best_only=True, mode='min' )
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Optimizer
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
optimizer = tf.keras.optimizers.Nadam(
learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-07,
name='Nadam'
) # 0.00001
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Loss Fn
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# 1
lossfn = tf.keras.losses.MeanSquaredLogarithmicError(reduction=tf.keras.losses.Reduction.AUTO, name='mean_squared_logarithmic_error')
# 2
# lossfn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Summary
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model.compile(optimizer=optimizer, loss=lossfn, metrics=['accuracy'])
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Training
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
history = model.fit(batched_features, epochs=1 ,validation_data=(batched_features), callbacks=[cp_callback]) # epochs=500 # , callbacks=[cp_callback, tb_callback]
checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(checkpoint_dir)
input('...')
输出:
2022-03-08 10:33:06.965274: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
1/1 [==============================] - ETA: 0s - **loss: 0.0154** - accuracy: 0.0000e+002022-03-08 10:33:16.175845: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
1/1 [==============================] - 31s 31s/step - **loss: 0.0154** - accuracy: 0.0000e+00 - val_loss: 0.0074 - val_accuracy: 0.0000e+00
...
出现问题是因为tf.Checkpoint.restore需要检查点网络的存储目录,而不是特定文件(或者,我认为是特定文件-./weights/ckpt-40.data- 00000-of-00001)
当没有给它一个有效的目录时,它会默默地继续下一行代码,不会更新网络或抛出错误。解决方法是为它提供包含相关检查点文件的目录,而不仅仅是我认为相关的文件。
我正在尝试了解如何使用 tensorflow.train.Checkpoint.restore
恢复 saved/checkpointed 网络。
我正在使用强烈基于 Google 的 Colab 教程创建 pix2pix GAN 的代码。下面,我摘录了关键部分,它只是尝试实例化一个新网络,然后用保存和检查点的先前网络的权重填充它。
我通过对网络的所有权重求和来为网络的特定实例分配一个唯一的(大概)ID 号。我在创建网络时和尝试恢复检查点网络后都比较了这些 ID 号
def main(opt):
# Initialize pix2pix GAN using arguments input from command line
p2p = Pix2Pix(vars(opt))
print(opt)
# print sum of initial weights for net
print("Init Model Weights:",
sum([x.numpy().sum() for x in p2p.generator.weights]))
# Create or read from model checkpoints
checkpoint = tf.train.Checkpoint(generator_optimizer=p2p.generator_optimizer,
discriminator_optimizer=p2p.discriminator_optimizer,
generator=p2p.generator,
discriminator=p2p.discriminator)
# print sum of weights from checkpoint, to ensure it has access
# to relevant regions of p2p
print("Checkpoint Weights:",
sum([x.numpy().sum() for x in checkpoint.generator.weights]))
# Recover Checkpointed net
checkpoint.restore(tf.train.latest_checkpoint(opt.weights)).expect_partial()
# print sum of weights for p2p & checkpoint after attempting to restore saved net
print("Restore Model Weights:",
sum([x.numpy().sum() for x in p2p.generator.weights]))
print("Restored Checkpoint Weights:",
sum([x.numpy().sum() for x in checkpoint.generator.weights]))
print("Done.")
if __name__ == '__main__':
opt = parse_opt()
main(opt)
我运行这段代码得到的输出如下:
Namespace(channels='1', data='data', img_size=256, output='output', weights='weights/ckpt-40.data-00000-of-00001')
## These are the input arguments, the images have only 1 channel (they're gray scale)
## The directory with data is ./data, the images are 265x256
## The output directory is ./output
## The checkpointed net is stored in ./weights/ckpt-40.data-00000-of-00001
## Sums of nets' weights
Init Model Weights: 11047.206374436617
Checkpoint Weights: 11047.206374436617
Restore Model Weights: 11047.206374436617
Restored Checkpoint Weights: 11047.206374436617
Done.
虽然 p2p
和 checkpoint
似乎可以访问内存中的相同位置,但在恢复检查点版本之前和之后网络的权重总和没有变化。
为什么我没有恢复保存的网络?
我的替代方法是使用回调和恢复,您可以为他们确定的检查点命名图层。
示例:
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: DataSet
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
DATA = adding_array_DATA(DATA, action, reward, gamescores, step)
dataset = tf.data.Dataset.from_tensor_slices((tf.constant(DATA, dtype=tf.float32),tf.constant(np.reshape(0, (1, 1, 1, 1)))))
batched_features = dataset
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Initialize
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(1200, 1)),
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True, return_state=False)),
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)),
])
model.add(layers.Flatten())
model.add(layers.Dense(64))
model.add(layers.Dense(2))
model.summary()
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Callback
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir, monitor='val_loss',
verbose=0, save_best_only=True, mode='min' )
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Optimizer
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
optimizer = tf.keras.optimizers.Nadam(
learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-07,
name='Nadam'
) # 0.00001
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Loss Fn
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# 1
lossfn = tf.keras.losses.MeanSquaredLogarithmicError(reduction=tf.keras.losses.Reduction.AUTO, name='mean_squared_logarithmic_error')
# 2
# lossfn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Summary
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model.compile(optimizer=optimizer, loss=lossfn, metrics=['accuracy'])
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Training
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
history = model.fit(batched_features, epochs=1 ,validation_data=(batched_features), callbacks=[cp_callback]) # epochs=500 # , callbacks=[cp_callback, tb_callback]
checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(checkpoint_dir)
input('...')
输出:
2022-03-08 10:33:06.965274: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
1/1 [==============================] - ETA: 0s - **loss: 0.0154** - accuracy: 0.0000e+002022-03-08 10:33:16.175845: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
1/1 [==============================] - 31s 31s/step - **loss: 0.0154** - accuracy: 0.0000e+00 - val_loss: 0.0074 - val_accuracy: 0.0000e+00
...
出现问题是因为tf.Checkpoint.restore需要检查点网络的存储目录,而不是特定文件(或者,我认为是特定文件-./weights/ckpt-40.data- 00000-of-00001)
当没有给它一个有效的目录时,它会默默地继续下一行代码,不会更新网络或抛出错误。解决方法是为它提供包含相关检查点文件的目录,而不仅仅是我认为相关的文件。