Keras 中的批量归一化

BatchNormalization in Keras

如何在 keras BatchNormalization 中更新移动均值和移动方差?

我在 tensorflow 文档中找到了这个,但我不知道放在哪里 train_op 或者如何使用 keras 模型:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize( loss )

我找不到任何帖子说明如何使用 train_op 以及是否可以在 model.compile 中使用它。

如果您只需要用一些新值更新现有模型的权重,那么您可以执行以下操作:

w = model.get_layer('batchnorm_layer_name').get_weights()
# Order: [gamma, beta, mean, std]
for j in range(len(w[0])):
    gamma = w[0][j]
    beta = w[1][j]
    run_mean = w[2][j]
    run_std = w[3][j]
    w[2][j] = new_run_mean_value1
    w[3][j] = new_run_std_value2

model.get_layer('batchnorm_layer_name').set_weights(w)

如果您使用 BatchNormalization 层,则无需手动更新移动均值和方差。 Keras 负责在训练期间更新这些参数,并在测试期间保持它们固定(通过使用 model.predictmodel.evaluate 函数,与 model.fit_generator 和朋友一样)。

Keras 还跟踪学习阶段,因此训练期间的代码路径不同 运行 和 validation/testing。

这个问题有两种解释:第一种是假设目标是使用高水平训练 api,这个问题由 Matias Valdenegro 回答。

第二个——如评论中所讨论的——是是否可以使用标准张量流优化器进行批量归一化,如此处 keras a simplified tensorflow interface 和部分 "Collecting trainable weights and state updates" 所述。正如那里提到的,更新操作可以在 layer.updates 中访问,而不是在 tf.GraphKeys.UPDATE_OPS 中访问,事实上,如果你在 tensorflow 中有一个 keras 模型,你可以使用标准的 tensorflow 优化器和批量规范化来优化,就像这样

update_ops  = model.updates
with tf.control_dependencies(update_ops):
     train_op = optimizer.minimize( loss )

然后使用 tensorflow 会话获取 train_op。为了区分批量归一化层的训练和评估模式,您需要提供 学习 keras 引擎的阶段状态(见 "Different behaviors during training and testing" 与上面给出的 tutorial page 相同)。这会像这样工作

... 
# train
lo, _ = tf_sess.run(fetches=[loss, train_step],
                    feed_dict={tf_batch_data: bd,
                               tf_batch_labels: bl,
                               tensorflow.keras.backend.learning_phase(): 1})

...

# eval
lo = tf_sess.run(fetches=[loss],
                    feed_dict={tf_batch_data: bd,
                               tf_batch_labels: bl,
                               tensorflow.keras.backend.learning_phase(): 0})

我在 tensorflow 1.12 中尝试过这个,它适用于包含批量归一化的模型。鉴于我现有的 tensorflow 代码和接近 tensorflow 2.0 版,我很想自己使用这种方法,但鉴于 tensorflow 文档中没有提到这种方法,我不确定这是否会得到长期支持,我最终决定不使用它并投入更多资金来更改代码以使用高级 api.