恢复训练好的tensorflow模型,编辑与节点关联的值,并保存

Restore trained tensorflow model, edit the value associated with a node, and save it

我用tensorflow训练了一个模型,在训练过程中使用了batch normalization。批量归一化要求用户传递一个布尔值,称为 is_training,以设置模型是处于训练阶段还是测试阶段。

训练模型时,is_training设置为常量如下图

is_training = tf.constant(True, dtype=tf.bool, name='is_training')

我已经保存了训练好的模型,文件包括检查点、.meta 文件、.index 文件和一个.data。我想恢复模型并使用它进行 运行 推理。 该模型无法重新训练。所以,我想恢复现有的模型,将 is_training 的值设置为 False ,然后将模型保存回来。 如何编辑与该节点关联的布尔值,并再次保存模型?

您可以使用 tf.train.import_meta_graphinput_map 参数将图形张量重新映射到更新后的值。

config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    # define the new is_training tensor
    is_training = tf.constant(False, dtype=tf.bool, name='is_training')

    # now import the graph using the .meta file of the checkpoint
    saver = tf.train.import_meta_graph(
    '/path/to/model.meta', input_map={'is_training:0':is_training})

    # restore all weights using the model checkpoint 
    saver.restore(sess, '/path/to/model')

    # save updated graph and variables values
    saver.save(sess, '/path/to/new-model-name')