Changing BatchNormalization momentum while training in TensorFlow 2
I want the batch normalization running statistics (mean and variance) to converge by the end of training, which requires increasing the batch normalization momentum from some initial value up to 1.0. I managed to change the momentum with a custom Callback, but it only works when my model is compiled to run eagerly. Toy example (it sets momentum=1.0 after epoch zero, so moving_mean should stop updating):
import tensorflow as tf # version 2.3.1
import tensorflow_datasets as tfds
ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True, as_supervised=True)
ds_train = ds_train.batch(128)
ds_test = ds_test.batch(128)
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(10),
    ]
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    # run_eagerly=True,
)

class BatchNormMomentumCallback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        last_bn_layer = None
        for layer in self.model.layers:
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                if epoch == 0:
                    layer.momentum = 0.99
                else:
                    layer.momentum = 1.0
                last_bn_layer = layer
        if last_bn_layer:
            tf.print("Momentum=" + str(last_bn_layer.moving_mean[-1].numpy()))  # Should not change after epoch 1

batchnorm_decay = BatchNormMomentumCallback()
model.fit(ds_train, epochs=6, validation_data=ds_test, callbacks=[batchnorm_decay], verbose=0)
Output (obtained with run_eagerly=False):
Momentum=0.0
Momentum=-102.20184
Momentum=-106.04614
Momentum=-116.36204
Momentum=-129.995
Momentum=-123.70443
Expected output (obtained with run_eagerly=True):
Momentum=0.0
Momentum=-5.9038606
Momentum=-5.9038606
Momentum=-5.9038606
Momentum=-5.9038606
Momentum=-5.9038606
My guess is that in graph mode TF compiles the model into a graph with momentum defined as 0.99 and keeps using that value inside the graph, so momentum is never actually updated by BatchNormMomentumCallback.
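That guess is consistent with how tf.function tracing works in general: a plain Python attribute is read once while the graph is traced and frozen into it as a constant, whereas a tf.Variable is read each time the compiled graph runs. A minimal sketch (illustrative only, not part of the model above):

import tensorflow as tf

momentum_py = 0.99                   # plain Python float
momentum_var = tf.Variable(0.99)     # TF variable

@tf.function
def read_momentum():
    frozen = tf.constant(momentum_py)    # read once, at trace time, and baked into the graph
    live = tf.identity(momentum_var)     # read from the variable on every call
    return frozen, live

print(read_momentum())    # (0.99, 0.99)
momentum_py = 1.0
momentum_var.assign(1.0)
print(read_momentum())    # (0.99, 1.0) - the Python float stayed frozen in the graph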
Question:
Is there a way to update the momentum value that is baked into the compiled graph while training? I would like to update momentum without eager mode (i.e. keeping run_eagerly=False), because training efficiency matters.
I would suggest simply using a custom training loop for your use case. It gives you all the flexibility you need:
import tensorflow as tf # version 2.3.1
import tensorflow_datasets as tfds
ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True, as_supervised=True)
ds_train = ds_train.batch(128)
ds_test = ds_test.batch(128)
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(10),
    ]
)
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
batch_norm_layer = model.layers[2]
@tf.function
def train_step(epoch, model, batch):
    if epoch == 0:
        batch_norm_layer.momentum = 0.99
    else:
        batch_norm_layer.momentum = 1.0
    with tf.GradientTape() as tape:
        x_batch_train, y_batch_train = batch
        logits = model(x_batch_train, training=True)
        loss_value = loss_fn(y_batch_train, logits)
    train_acc_metric.update_state(y_batch_train, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

epochs = 6
for epoch in range(epochs):
    tf.print("\nStart of epoch %d" % (epoch,))
    tf.print("Momentum = ", batch_norm_layer.moving_mean[-1], summarize=-1)
    for batch in ds_train:
        train_step(epoch, model, batch)
    train_acc = train_acc_metric.result()
    tf.print("Training acc over epoch: %.4f" % (float(train_acc),))
    train_acc_metric.reset_states()
Start of epoch 0
Momentum = 0
Training acc over epoch: 0.9158
Start of epoch 1
Momentum = -20.2749767
Training acc over epoch: 0.9634
Start of epoch 2
Momentum = -20.2749767
Training acc over epoch: 0.9755
Start of epoch 3
Momentum = -20.2749767
Training acc over epoch: 0.9826
Start of epoch 4
Momentum = -20.2749767
Training acc over epoch: 0.9876
Start of epoch 5
Momentum = -20.2749767
Training acc over epoch: 0.9915
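Note that epoch is passed into train_step as a plain Python int, so tf.function builds a new trace for each distinct epoch value; the Python-level momentum assignment runs during each of those traces, which is how the new value ends up in the compiled step. A small illustration of this retracing behaviour (a generic sketch, not part of the training code above):

@tf.function
def show_epoch(epoch):
    print("tracing for epoch", epoch)  # runs only while a new trace is being built
    return tf.constant(epoch)

show_epoch(0)  # prints "tracing for epoch 0"
show_epoch(0)  # no print - the cached trace is reused
show_epoch(1)  # prints "tracing for epoch 1" - a new Python value triggers a new trace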
A quick test shows that the function decorated with tf.function performs better:
import tensorflow as tf # version 2.3.1
import tensorflow_datasets as tfds
import timeit
ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True, as_supervised=True)
ds_train = ds_train.batch(128)
ds_test = ds_test.batch(128)
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(10),
    ]
)
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
batch_norm_layer = model.layers[2]
@tf.function
def train_step(epoch, model, batch):
    if epoch == 0:
        batch_norm_layer.momentum = 0.99
    else:
        batch_norm_layer.momentum = 1.0
    with tf.GradientTape() as tape:
        x_batch_train, y_batch_train = batch
        logits = model(x_batch_train, training=True)
        loss_value = loss_fn(y_batch_train, logits)
    train_acc_metric.update_state(y_batch_train, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

def train_step_without_tffunction(epoch, model, batch):
    if epoch == 0:
        batch_norm_layer.momentum = 0.99
    else:
        batch_norm_layer.momentum = 1.0
    with tf.GradientTape() as tape:
        x_batch_train, y_batch_train = batch
        logits = model(x_batch_train, training=True)
        loss_value = loss_fn(y_batch_train, logits)
    train_acc_metric.update_state(y_batch_train, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
epochs = 6
for epoch in range(epochs):
    tf.print("\nStart of epoch %d" % (epoch,))
    tf.print("Momentum = ", batch_norm_layer.moving_mean[-1], summarize=-1)
    test = True
    for batch in ds_train:
        train_step(epoch, model, batch)
        if test:
            tf.print("TF function:", timeit.timeit(lambda: train_step(epoch, model, batch), number=10))
            tf.print("Eager function:", timeit.timeit(lambda: train_step_without_tffunction(epoch, model, batch), number=10))
            test = False
    train_acc = train_acc_metric.result()
    tf.print("Training acc over epoch: %.4f" % (float(train_acc),))
    train_acc_metric.reset_states()
Start of epoch 0
Momentum = 0
TF function: 0.02285163299893611
Eager function: 0.11109527599910507
Training acc over epoch: 0.9229
Start of epoch 1
Momentum = -88.1852188
TF function: 0.024091466999379918
Eager function: 0.1109461480009486
Training acc over epoch: 0.9639
Start of epoch 2
Momentum = -88.1852188
TF function: 0.02331122400210006
Eager function: 0.11751473100230214
Training acc over epoch: 0.9756
Start of epoch 3
Momentum = -88.1852188
TF function: 0.02656845700039412
Eager function: 0.1121610670015798
Training acc over epoch: 0.9830
Start of epoch 4
Momentum = -88.1852188
TF function: 0.02821972700257902
Eager function: 0.15709391699783737
Training acc over epoch: 0.9877
Start of epoch 5
Momentum = -88.1852188
TF function: 0.02441513300072984
Eager function: 0.10921925399816246
Training acc over epoch: 0.9917
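If you want to stay with model.fit instead, one possible workaround is to back the momentum with a non-trainable tf.Variable and assign to it from the callback, so the traced graph reads the variable's current value rather than a frozen constant. This is only a sketch: it assumes BatchNormalization accepts a tf.Variable for momentum and re-reads it on every moving-average update, which I have not verified, and the fused implementation used for 4-D inputs may not support it:

import tensorflow as tf

# Hypothetical variation of the original callback (not verified):
bn_momentum = tf.Variable(0.99, trainable=False, dtype=tf.float32)

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128),
        tf.keras.layers.BatchNormalization(momentum=bn_momentum),  # assumes a Variable is accepted here
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(10),
    ]
)

class BatchNormMomentumVariableCallback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        # assign() mutates the variable that the traced train step already captures
        bn_momentum.assign(0.99 if epoch == 0 else 1.0)

If that assumption does not hold in your TF version, the custom training loop above remains the reliable option.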