Updating batch_normalization mean & variance using Estimator API

The documentation is not 100% clear on this:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

(See https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization)

Does this mean that the following is all that is needed to save moving_mean and moving_variance?

def model_fn(features, labels, mode, params):
    training = mode == tf.estimator.ModeKeys.TRAIN
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    x = tf.reshape(features, [-1, 64, 64, 3])
    x = tf.layers.batch_normalization(x, training=training)

    # ...

    with tf.control_dependencies(extra_update_ops):
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

In other words, is simply using

with tf.control_dependencies(extra_update_ops):

enough to take care of saving moving_mean and moving_variance?

Yes, adding those control dependencies will save the mean and variance.

It turns out those values *can* be saved automatically. The corner case is that if you fetch the update-ops collection *before* the batch normalization op has been added to the graph, the collection will be empty. This was previously undocumented, but is documented now.

The caveat when using batch_norm is to call tf.get_collection(tf.GraphKeys.UPDATE_OPS) *after* calling tf.layers.batch_normalization.