Pytorch nn.functional.batch_norm 用于 2D 输入

Pytorch nn.functional.batch_norm for 2D input

我目前正在实施一个模型,我需要在测试期间更改 运行 均值和标准差。因此,我认为 nn.functional.batch_norm 是比 nn.BatchNorm2d

更好的选择

但是,我有一批图像作为输入,目前我不确定如何接收这些图像。我如何将 nn.functional.batch_norm 应用到成批的 2D 图像上?

我目前的代码是这样的,我post虽然这是不正确的:

mu = torch.mean(inp[0])
stddev = torch.std(inp[0])
x = nn.functional.batch_norm(inp[0], mu, stddev, training=True, momentum=0.9)

关键是2D batchnorm对每个通道进行相同的归一化。也就是说,如果你有一批形状为 (N, C, H, W) 的数据,那么你的 mu 和 stddev 应该是形状 (C,)。如果您的图像没有通道尺寸,请使用 view.

添加一个

警告: 如果您设置 training=True,则 batch_norm 计算并使用有争议的批次的适当规范化统计信息(这意味着我们不需要自己计算均值和标准差)。您争论的 mu 和 stddev 应该是所有训练批次的 运行 均值和 运行 标准差。这些张量在 batch_norm 函数中使用新的批处理统计信息进行更新。

# inp is shape (N, C, H, W)
n_chans = inp.shape[1]
running_mu = torch.zeros(n_chans) # zeros are fine for first training iter
running_std = torch.ones(n_chans) # ones are fine for first training iter
x = nn.functional.batch_norm(inp, running_mu, running_std, training=True, momentum=0.9)
# running_mu and running_std now have new values

如果您只想使用自己的批处理统计信息,试试这个:

# inp is shape (N, C, H, W)
n_chans = inp.shape[1]
reshaped_inp = inp.permute(1,0,2,3).contiguous().view(n_chans, -1) # shape (C, N*W*H)
mu = reshaped_inp.mean(-1)
stddev = reshaped_inp.std(-1)
x = nn.functional.batch_norm(inp, mu, stddev, training=False)