Tensorflow save/restore 批量规范

Tensorflow save/restore batch norm

我在 Tensorflow 中训练了一个带有批量规范的模型。我想保存模型并将其恢复以供进一步使用。批量规范由

完成
def batch_norm(input, phase):
    return tf.layers.batch_normalization(input, training=phase)

其中阶段在训练期间为 True,在测试期间为 False

好像只是调用

saver = tf.train.Saver()
saver.save(sess, savedir + "ckpt")

效果不佳,因为当我恢复模型时,它首先显示恢复成功。它还说 Attempting to use uninitialized value batch_normalization_585/beta 如果我只是 运行 图中的一个节点。这是否与未正确保存模型或我错过的其他内容有关?

不确定这是否需要解释,但以防万一(以及其他潜在观众)。

每当您在 TensorFlow 中创建操作时,都会将一个新节点添加到图中。图中的两个节点不能具有相同的名称。您可以定义您创建的任何节点的名称,但如果您不提供名称,TensorFlow 将以一种确定性的方式为您选择一个(也就是说,不是随机的,而是总是以相同的顺序)。如果你把两个数字相加,它可能是 Add,但是如果你再做一个相加,因为没有两个节点可以同名,它可能是 Add_2。一旦在图中创建了一个节点,它的名称就不能更改。许多功能依次创建几个子节点;例如,tf.layers.batch_normalization 创建一些内部变量 betagamma.

按以下方式保存和恢复作品:

  1. 您创建一个代表您想要的模型的图表。此图包含将由保存器保存的变量。
  2. 你对该图进行初始化、训练或做任何你想做的事情,模型中的变量会被分配一些值。
  3. 您在保存器上调用 save 来将变量的值保存到文件中。
  4. 现在您在另一个图表 中重新创建模型(它可以是一个完全不同的 Python 会话,也可以只是与第一个图表共存的另一个图表)。模型的创建方式必须与第一个完全相同。
  5. 您在保存程序上调用 restore 以检索变量的值。

为了使其工作,第一张和第二张图中的变量名称必须完全相同

在您的示例中,TensorFlow 抱怨变量 batch_normalization_585/beta。您似乎在同一张图中调用了 tf.layers.batch_normalization 将近 600 次,因此您有那么多 beta 变量。我怀疑你真的需要那么多,所以我猜你只是在试验 API 并最终得到那么多副本。

这是一份应该有效的草稿:

import tensorflow as tf

def make_model():
    input = tf.placeholder(...)
    phase = tf.placeholder(...)
    input_norm = tf.layers.batch_normalization(input, training=phase))
    # Do some operations with input_norm
    output = ...
    saver = tf.train.Saver()
    return input, output, phase, saver

# We work with one graph first
g1 = tf.Graph()
with g1.as_default():
    input, output, phase, saver = make_model()
    with tf.Session() as sess:
        # Do your training or whatever...
        saver.save(sess, savedir + "ckpt")

# We work with a second different graph now
g2 = tf.Graph()
with g2.as_default():
    input, output, phase, saver = make_model()
    with tf.Session() as sess:
        saver.restore(sess, savedir + "ckpt")
        # Continue using your model...

同样,典型的情况不是并排放置两个图表,而是一个图表,然后在另一个 Python 会话中重新创建它,但最终两者都是一样的。重要的是,在这两种情况下,模型的创建方式相同(因此具有相同的节点名称)。

我也有 "Attempting to use uninitialized value batch_normalization_585/beta" 错误。这是因为通过像这样用空括号声明保护程序:

         saver = tf.train.Saver() 

saver 将保存 tf.trainable_variables() 中包含的变量,这些变量不包含 batch normalization 的移动平均值。要将此变量包含到保存的 ckpt 中,您需要执行以下操作:

         saver = tf.train.Saver(tf.global_variables())

其中保存了所有的变量,所以非常耗内存。或者您必须确定具有移动平均值或方差的变量并通过声明它们来保存它们:

         saver = tf.train.Saver(tf.trainable_variables() + list_of_extra_variables)