I am trying to resume training from a certain checkpoint (TensorFlow), because I'm using Colab and 12 hours aren't enough.
This is the part of the code I'm using:
checkpoint_dir = 'training_checkpoints1'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
Below is the training part:
EPOCHS = 900

for epoch in range(EPOCHS):
    start = time.time()

    hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset):
        loss = 0
        with tf.GradientTape() as tape:
            enc_output, enc_hidden = encoder(inp, hidden)
            dec_hidden = enc_hidden
            dec_input = tf.expand_dims([targ_lang.word2idx['<start>']] * batch_size, 1)

            # Teacher forcing - feeding the target as the next input
            for t in range(1, targ.shape[1]):
                # passing enc_output to the decoder
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
                loss += loss_function(targ[:, t], predictions)
                # using teacher forcing
                dec_input = tf.expand_dims(targ[:, t], 1)

        batch_loss = (loss / int(targ.shape[1]))
        total_loss += batch_loss
        variables = encoder.variables + decoder.variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         batch,
                                                         batch_loss.numpy()))

    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

    print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                        total_loss / num_batches))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
Now I want to restore this checkpoint and start training from there, but I don't know how.
path = "/content/drive/My Drive/training_checkpoints1/ckpt-9"
checkpoint.restore(path)
The result is:
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f6653263048>
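That CheckpointLoadStatus is not an error: restore() returns a status object rather than printing anything on success. A minimal sketch of verifying the load explicitly, assuming the same checkpoint and path as above:

status = checkpoint.restore(path)
# Raises an AssertionError if variables tracked by `checkpoint` were not
# found in the file; status.expect_partial() instead silences warnings
# about values in the file that went unused.
status.assert_existing_objects_matched()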
You should create a CheckpointManager at the start, like this:
checkpoint_path = os.path.abspath('.') + "/checkpoints"  # Put your path here

ckpt = tf.train.Checkpoint(encoder=encoder,
                           decoder=decoder,
                           optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
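With the manager in place, the saves inside the training loop can go through it instead of checkpoint.save(file_prefix=...); a minimal sketch, keeping the same every-2-epochs schedule as in the question:

# inside the epoch loop, replacing checkpoint.save(file_prefix=checkpoint_prefix)
if (epoch + 1) % 2 == 0:
    ckpt_manager.save()  # writes ckpt-1, ckpt-2, ... and keeps only the last 5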
Now, after running a few epochs, to resume from the latest checkpoint you should get it from the CheckpointManager:
start_epoch = 0
if ckpt_manager.latest_checkpoint:
    start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
    # restoring the latest checkpoint in checkpoint_path
    ckpt.restore(ckpt_manager.latest_checkpoint)
This will restore your session from the latest epoch.
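To actually continue training, start the epoch loop at start_epoch instead of 0. One caveat: the ckpt-N suffix counts how many times save() has been called, not epochs, so if you checkpoint every 2 epochs as in the question, ckpt-9 corresponds to epoch 18 and you would scale start_epoch accordingly. A minimal sketch:

# resume where the last checkpoint left off
for epoch in range(start_epoch, EPOCHS):
    ...  # training loop body exactly as above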