无法在 TensorFlow 2 中加载模型权重

Cannot load model weights in TensorFlow 2

在 TensorFlow 2.2 中保存模型权重后,我无法加载它们。权重似乎保存正确(我认为),但是,我无法加载预训练模型。

我当前的代码是:

segmentor = sequential_model_1()
discriminator = sequential_model_2()

def save_model(ckp_dir):
    # create directory, if it does not exist:
    utils.safe_mkdir(ckp_dir)

    # save weights
    segmentor.save_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'))
    discriminator.save_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'))

def load_pretrained_model(ckp_dir):
    try:
        segmentor.load_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'), skip_mismatch=True)
        discriminator.load_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'), skip_mismatch=True)
        print('Loading pre-trained model from: {0}'.format(ckp_dir))
    except ValueError:
        print('No pre-trained model available.')

然后我有训练循环:

# training loop:
for epoch in range(num_epochs):

    for image, label in dataset:
        train_step()

    # save best model I find during training:
    if this_is_the_best_model_on_validation_set():
        save_model(ckp_dir='logs_dir')

然后,在训练结束时 "for loop",我想加载最佳模型并用它进行测试。因此,我 运行:

# load saved model and do a test:
load_pretrained_model(ckp_dir='logs_dir')
test()

然而,这会导致 ValueError。我检查了应该保存权重的目录,它们就在那里!

知道我的代码有什么问题吗?我是否错误地加载了重量?

谢谢!

好的,这是您的问题 - 您的 try-except 块掩盖了真正的问题。删除它会得到 ValueError:

ValueError: When calling model.load_weights, skip_mismatch can only be set to True when by_name is True.

有两种方法可以缓解这种情况 - 您可以使用 by_name=True 调用 load_weights,或者根据需要删除 skip_mismatch=True。在测试您的代码时,这两种情况都适合我。

另一个考虑因素是,当您将鉴别器和分段器检查点都存储到日志目录时,您每次都会覆盖 checkpoint 文件。这包含两个字符串,它们提供特定模型检查点文件的路径。由于您将鉴别器保存在第二位,因此每次此文件都会说鉴别器而不引用分段器。您可以通过将每个模型存储在日志目录中的两个子目录中来缓解这种情况,即

logs_dir/
    + discriminator/
        + checkpoint
        + ...
    + segmentor/
        + checkpoint
        + ...

虽然在当前状态下您的代码在这种情况下可以工作。