在空间轴上计算 Pytorch 中的方差

Calculating Variance in Pytorch on Spatial axis

我正在尝试在 Pytorch 中计算方差,但无法在多轴上计算。

我在 Tensorflow 中完成了类似的事情,但无法在 Pytorch 上完成,因为 torch.var 函数将 int 作为维度而不是轴。 下面的代码是通道最后的代码,我希望轴=[2,3]

Lambda(lambda x: tf.nn.moments(x, axes=[1, 2]))

例如,如果 input_dims = (5, 10, 25, 25) 那么 output_dims 应该是 (5,10, 1, 1)。

您可以做的一件事是在应用 var() 方法之前,使用 tensor.view() 将您想要计算方差的所有维度展平为一个维度:

torch.var(x.view(x.shape[0], x.shape[1], 1, -1,), dim=3, keepdim=True)

我使用 keepdim=True 来保持我们计算方差的维度以获得所需的输出形状。