如何在没有显式 model.fit 的情况下设置 tf.keras.callbacks.ModelCheckpoint?

How to set tf.keras.callbacks.ModelCheckpoint without explicit model.fit?

我想在代码的转换过程中添加Checkpoint。我知道在model.fit中设置callbacks = callbacks的方法。但是,在代码中,没有通过 K.function 显式调用 model.fit insead,如下所示。谁能告诉我设置检查点的正确位置在哪里?完整代码可以通过this github link.

查看
vae_model = vae_util.create_vae(input_shape)
vae_model.compile(optimizer=opt, loss='mse')
rec_loss = vae_loss(vae_model.output, train_target)
total_loss = rec_loss
updates = opt.get_updates(total_loss, vae_model.trainable_weights)

iterate = K.function(vae_model.inputs + [train_target], [rec_loss], updates=updates)

eval_rec_loss = vae_loss(vae_model.output, test_target)
evaluate = K.function(vae_model.inputs + [test_target], [eval_rec_loss])   

原代码在line 139

中已有安全点保存