Keras BatchNormalization 仅适用于 axis=0 时的常量批暗淡?

Keras BatchNormalization only works for constant batch dim when axis=0?

以下代码显示了一种可行的方法和另一种失败的方法。

axis=0 上的 BatchNorm 不应依赖于批量大小,或者如果它依赖于批量大小,则应在文档中明确说明。

In [118]: tf.__version__
Out[118]: '2.0.0-beta1'



class M(tf.keras.models.Model):
import numpy as np
import tensorflow as tf

class M(tf.keras.Model):

    def __init__(self, axis):
        super().__init__()
        self.layer = tf.keras.layers.BatchNormalization(axis=axis, scale=False, center=True, input_shape=(6,))

    def call(self, x):
        out = self.layer(x)
        return out

def fails():
    m = M(axis=0)
    x = np.random.randn(3, 6).astype(np.float32)
    print(m(x))
    x = np.random.randn(2, 6).astype(np.float32)
    print(m(x))

def ok():
    m = M(axis=1)
    x = np.random.randn(3, 6).astype(np.float32)
    print(m(x))
    x = np.random.randn(2, 6).astype(np.float32)
    print(m(x))

编辑:

args 中的轴不是您认为的轴。

正如 and the Keras doc中所述,axis参数表示特征轴。这完全是有道理的,因为我们想要进行特征归一化,即对整个输入批次的每个特征进行归一化(这与我们可能对图像进行的特征归一化一致,例如从所有图像中减去 "mean pixel"数据集的图像)。

现在,您编写的 fails() 方法在这一行失败:

x = np.random.randn(2, 6).astype(np.float32)
print(m(x))

这是因为您在构建模型时将特征轴设置为 0,即第一个轴,因此在上述代码之前执行以下行时:

x = np.random.randn(3, 6).astype(np.float32)
print(m(x))

层的权重将基于 3 个特征构建(不要忘记您已将特征轴指示为 0,因此在形状输入中将有 3 个特征(3,6))。因此,当你给它一个形状为 (2,6) 的输入张量时,它会正确地引发错误,因为该张量中有 2 个特征,因此由于这种不匹配而无法进行归一化。

另一方面,ok() 方法有效,因为特征轴是最后一个轴,因此两个输入张量具有相同数量的特征,即 6。因此在两种情况下都可以对所有的特征进行标准化功能。