Pytorch BatchNorm2d RuntimeError: running_mean should contain 64 elements not 0

Pytorch BatchNorm2d RuntimeError: running_mean should contain 64 elements not 0

我正在使用 Octave Convolutions 并设置了一个 BatchNorm2d 适应,对于某些原因,我正在使用它

RuntimeError: running_mean should contain 64 elements not 0

我已经设置了一些调试打印来检查我的 Tensors 的尺寸有什么问题,但无法找到它。 这是我的 class:

class _BatchNorm2d(nn.Module):
  def __init__(self, num_features, alpha_in=0, alpha_out=0, eps=1e-5, momentum=0.1, affine=True,
               track_running_stats=True):
    super(_BatchNorm2d, self).__init__()
    hf_ch = int(num_features * (1 - alpha_out))
    lf_ch = num_features - hf_ch
    self.bnh = nn.BatchNorm2d(hf_ch)
    self.bnl = nn.BatchNorm2d(lf_ch)
  def forward(self, x):
    if isinstance(x, tuple):
        hf, lf = x
        print("IN ON BN: ",lf.shape if lf is not None else None) #DEBUGGING PRINT
        print(self.bnl)  #DEBUGGING PRINT
        hf = self.bnh(hf) if type(hf) == torch.Tensor else hf
        lf = self.bnh(lf) if type(lf) == torch.Tensor else lf #THIS IS THE LINE ACCUSING THE ERROR
        print("ENDED BN")
        return hf, lf
    else:
        return self.bnh(x)

这里是打印错误:

IN ON BN:  torch.Size([32, 64, 3, 3])
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

在我看来该函数应该有效,因为 x 有 64 个通道,而 bn 需要 64 个通道。

编辑: 提到错误仅发生在 alpha 值为 1 时可能也很重要。但是,我不明白,因为体积仍然相同。

已解决。这是低频 BN 的调用错误。

    hf = self.bnh(hf) if type(hf) == torch.Tensor else hf
    lf = self.bnh(lf) if type(lf) == torch.Tensor else lf

应该是

    hf = self.bnh(hf) if type(hf) == torch.Tensor else hf
    lf = self.bnl(lf) if type(lf) == torch.Tensor else lf