Updating batch_normalization mean & variance using Estimator API
The documentation isn't 100% clear on this:
Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:
(See https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization)
Does this mean that the following is all that is needed to save the moving_mean and moving_variance?
def model_fn(features, labels, mode, params):
    training = mode == tf.estimator.ModeKeys.TRAIN
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    x = tf.reshape(features, [-1, 64, 64, 3])
    x = tf.layers.batch_normalization(x, training=training)
    # ...
    with tf.control_dependencies(extra_update_ops):
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
In other words, does simply wrapping the training op in

with tf.control_dependencies(extra_update_ops):

take care of saving the moving_mean and moving_variance?
Yes, adding those control dependencies will save the mean and variance.
It turns out these values can be saved automatically. The edge case is that if you fetch the update-ops collection before the batch normalization ops have been added to the graph, the collection will be empty. This used to be undocumented, but it is documented now.
The caveat when using batch_norm is to call tf.get_collection(tf.GraphKeys.UPDATE_OPS) after calling tf.layers.batch_normalization. Note that the snippet in the question fetches the collection before the layer is built, so extra_update_ops would be empty there.
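The ordering pitfall can be illustrated without TensorFlow itself. The sketch below is a minimal plain-Python stand-in for TF1's graph collections (the names add_to_collection, get_collection, and the stub layer are modeled on the TF API but are hypothetical local functions, not TensorFlow calls): fetching the collection before the layer registers its update ops yields an empty list, while fetching afterwards yields both moving-average updates.

```python
# Minimal stand-in for TF1 graph collections, modeling only the part
# relevant to the pitfall: when the snapshot of UPDATE_OPS is taken.

UPDATE_OPS = "update_ops"
_collections = {}

def add_to_collection(key, op):
    _collections.setdefault(key, []).append(op)

def get_collection(key):
    # Like tf.get_collection, returns a snapshot of the current contents.
    return list(_collections.get(key, []))

def batch_normalization_stub():
    # A real tf.layers.batch_normalization call registers its moving-average
    # update ops in UPDATE_OPS as a side effect; we model only that part.
    add_to_collection(UPDATE_OPS, "update_moving_mean")
    add_to_collection(UPDATE_OPS, "update_moving_variance")

# Wrong order: fetch the collection before the layer is built -> empty list.
too_early = get_collection(UPDATE_OPS)

batch_normalization_stub()

# Right order: fetch after the layer is built -> both update ops present.
extra_update_ops = get_collection(UPDATE_OPS)

print(too_early)         # []
print(extra_update_ops)  # ['update_moving_mean', 'update_moving_variance']
```

In the real model_fn this means the tf.get_collection(tf.GraphKeys.UPDATE_OPS) line should come after the tf.layers.batch_normalization call (and after any other layers that register update ops), immediately before building the train_op.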