如何使用 Keras 作为高级 API 在 tensorflow 上实现批量归一化

How to implement Batch Normalization on tensorflow with Keras as a high-level API

BatchNormalization (BN) 在训练和推理时的操作略有不同。在训练中,它使用当前小批量的平均值和方差来缩放其输入;这意味着应用批归一化的确切结果不仅取决于当前输入,还取决于小批量的所有其他元素。在我们想要确定性结果的推理模式下,这显然是不可取的。因此,在这种情况下,将使用整个训练集的全局平均值和方差的固定统计量。

在 Tensorflow 中,此行为由布尔开关 training 控制,需要在调用层时指定,请参阅 https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization。使用 Keras high-level API 时如何处理这个开关?我是否正确地假设它是自动处理的,这取决于我们使用的是 model.fit(x, ...) 还是 model.predict(x, ...)


为了测试这个,我写了这个例子。我们从随机分布开始,我们想要对输入是正数还是负数进行分类。然而,我们还有一个来自不同分布的测试数据集,其中输入被 2 取代(因此标签检查是否 x>2)。

import numpy as np
from math import ceil
from tensorflow.python.data import Dataset
from tensorflow.python.keras import Input, Model
from tensorflow.python.keras.layers import Dense, BatchNormalization

np.random.seed(18)
xt = np.random.randn(10_000, 1)
yt = np.array([[int(x > 0)] for x in xt])
train_data = Dataset.from_tensor_slices((xt, yt)).shuffle(10_000).repeat().batch(32).prefetch(2)

xv = np.random.randn(100, 1)
yv = np.array([[int(x > 0)] for x in xv])
valid_data = Dataset.from_tensor_slices((xv, yv)).repeat().batch(32).prefetch(2)

xs = np.random.randn(100, 1) + 2
ys = np.array([[int(x > 2)] for x in xs])
test_data = Dataset.from_tensor_slices((xs, ys)).repeat().batch(32).prefetch(2)

x = Input(shape=(1,))
a = BatchNormalization()(x)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a)
y = Dense(1, activation='sigmoid')(a)
model = Model(inputs=x, outputs=y, )
model.summary()

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_data, epochs=10, steps_per_epoch=ceil(10_000 / 32), validation_data=valid_data,
          validation_steps=ceil(100 / 32))
zs = model.predict(test_data, steps=ceil(100 / 32))
print(sum([ys[i] == int(zs[i] > 0.5) for i in range(100)]))

运行 代码打印值 0.5,这意味着一半的示例被正确标记。如果系统使用训练集上的全局统计信息来实现 BN,这就是我所期望的。

如果我们将 BN 层更改为读取

x = Input(shape=(1,))
a = BatchNormalization()(x, training=True)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a, training=True)
y = Dense(1, activation='sigmoid')(a)

和 运行 我们再次找到代码 0.87。始终强制训练状态,正确预测的百分比发生了变化。这与 model.predict(x, ...) 现在使用 mini-batch 的统计数据来实现 BN 的想法是一致的,因此能够稍微 "correct" 训练和测试数据之间的源分布不匹配。

对吗?

如果我对你的问题的理解正确,那么是的,keras 会根据 fitpredict/evaluate 自动管理训练与推理行为。该标志称为 learning_phase,它决定了 batch norm、dropout 和其他潜在事物的行为。当前学习阶段可以用keras.backend.learning_phase()查看,用keras.backend.set_learning_phase()设置。

https://keras.io/backend/#learning_phase