Keras BatchNormalization 仅适用于 axis=0 时的常量批暗淡?
Keras BatchNormalization only works for constant batch dim when axis=0?
以下代码显示了一种可行的方法和另一种失败的方法。
axis=0 上的 BatchNorm 不应依赖于批量大小,或者如果它依赖于批量大小,则应在文档中明确说明。
In [118]: tf.__version__
Out[118]: '2.0.0-beta1'
class M(tf.keras.models.Model):
import numpy as np
import tensorflow as tf
class M(tf.keras.Model):
def __init__(self, axis):
super().__init__()
self.layer = tf.keras.layers.BatchNormalization(axis=axis, scale=False, center=True, input_shape=(6,))
def call(self, x):
out = self.layer(x)
return out
def fails():
m = M(axis=0)
x = np.random.randn(3, 6).astype(np.float32)
print(m(x))
x = np.random.randn(2, 6).astype(np.float32)
print(m(x))
def ok():
m = M(axis=1)
x = np.random.randn(3, 6).astype(np.float32)
print(m(x))
x = np.random.randn(2, 6).astype(np.float32)
print(m(x))
编辑:
args 中的轴不是您认为的轴。
正如 and the Keras doc中所述,axis
参数表示特征轴。这完全是有道理的,因为我们想要进行特征归一化,即对整个输入批次的每个特征进行归一化(这与我们可能对图像进行的特征归一化一致,例如从所有图像中减去 "mean pixel"数据集的图像)。
现在,您编写的 fails()
方法在这一行失败:
x = np.random.randn(2, 6).astype(np.float32)
print(m(x))
这是因为您在构建模型时将特征轴设置为 0,即第一个轴,因此在上述代码之前执行以下行时:
x = np.random.randn(3, 6).astype(np.float32)
print(m(x))
层的权重将基于 3 个特征构建(不要忘记您已将特征轴指示为 0,因此在形状输入中将有 3 个特征(3,6)
)。因此,当你给它一个形状为 (2,6)
的输入张量时,它会正确地引发错误,因为该张量中有 2 个特征,因此由于这种不匹配而无法进行归一化。
另一方面,ok()
方法有效,因为特征轴是最后一个轴,因此两个输入张量具有相同数量的特征,即 6。因此在两种情况下都可以对所有的特征进行标准化功能。
以下代码显示了一种可行的方法和另一种失败的方法。
axis=0 上的 BatchNorm 不应依赖于批量大小,或者如果它依赖于批量大小,则应在文档中明确说明。
In [118]: tf.__version__
Out[118]: '2.0.0-beta1'
class M(tf.keras.models.Model):
import numpy as np
import tensorflow as tf
class M(tf.keras.Model):
def __init__(self, axis):
super().__init__()
self.layer = tf.keras.layers.BatchNormalization(axis=axis, scale=False, center=True, input_shape=(6,))
def call(self, x):
out = self.layer(x)
return out
def fails():
m = M(axis=0)
x = np.random.randn(3, 6).astype(np.float32)
print(m(x))
x = np.random.randn(2, 6).astype(np.float32)
print(m(x))
def ok():
m = M(axis=1)
x = np.random.randn(3, 6).astype(np.float32)
print(m(x))
x = np.random.randn(2, 6).astype(np.float32)
print(m(x))
编辑:
args 中的轴不是您认为的轴。
正如axis
参数表示特征轴。这完全是有道理的,因为我们想要进行特征归一化,即对整个输入批次的每个特征进行归一化(这与我们可能对图像进行的特征归一化一致,例如从所有图像中减去 "mean pixel"数据集的图像)。
现在,您编写的 fails()
方法在这一行失败:
x = np.random.randn(2, 6).astype(np.float32)
print(m(x))
这是因为您在构建模型时将特征轴设置为 0,即第一个轴,因此在上述代码之前执行以下行时:
x = np.random.randn(3, 6).astype(np.float32)
print(m(x))
层的权重将基于 3 个特征构建(不要忘记您已将特征轴指示为 0,因此在形状输入中将有 3 个特征(3,6)
)。因此,当你给它一个形状为 (2,6)
的输入张量时,它会正确地引发错误,因为该张量中有 2 个特征,因此由于这种不匹配而无法进行归一化。
另一方面,ok()
方法有效,因为特征轴是最后一个轴,因此两个输入张量具有相同数量的特征,即 6。因此在两种情况下都可以对所有的特征进行标准化功能。