如何在没有显式 model.fit 的情况下设置 tf.keras.callbacks.ModelCheckpoint?
How to set tf.keras.callbacks.ModelCheckpoint without explicit model.fit?
我想在代码的转换过程中添加Checkpoint
。我知道在model.fit
中设置callbacks = callbacks
的方法。但是,在代码中,没有通过 K.function
显式调用 model.fit
insead,如下所示。谁能告诉我设置检查点的正确位置在哪里?完整代码可以通过this github link.
查看
vae_model = vae_util.create_vae(input_shape)
vae_model.compile(optimizer=opt, loss='mse')
rec_loss = vae_loss(vae_model.output, train_target)
total_loss = rec_loss
updates = opt.get_updates(total_loss, vae_model.trainable_weights)
iterate = K.function(vae_model.inputs + [train_target], [rec_loss], updates=updates)
eval_rec_loss = vae_loss(vae_model.output, test_target)
evaluate = K.function(vae_model.inputs + [test_target], [eval_rec_loss])
原代码在line 139
中已有安全点保存
我想在代码的转换过程中添加Checkpoint
。我知道在model.fit
中设置callbacks = callbacks
的方法。但是,在代码中,没有通过 K.function
显式调用 model.fit
insead,如下所示。谁能告诉我设置检查点的正确位置在哪里?完整代码可以通过this github link.
vae_model = vae_util.create_vae(input_shape)
vae_model.compile(optimizer=opt, loss='mse')
rec_loss = vae_loss(vae_model.output, train_target)
total_loss = rec_loss
updates = opt.get_updates(total_loss, vae_model.trainable_weights)
iterate = K.function(vae_model.inputs + [train_target], [rec_loss], updates=updates)
eval_rec_loss = vae_loss(vae_model.output, test_target)
evaluate = K.function(vae_model.inputs + [test_target], [eval_rec_loss])
原代码在line 139
中已有安全点保存