Pytorch 计算堆叠张量的单独损失

Pytorch compute separate losses of stacked tensors

我有一个 (N, k, 1) 张量,它来自 k 网络的堆叠标量预测。 N 是批量大小。

所有预测的目标都是相同的,t。如何有效地计算损失(例如 MSE)?我现在正在做的是拆分每个网络的预测并对单独的损失求和。


stacked_predictions  # (N, k, 1) tensor with the predictions
t  # common target

predictions = [prediction[:, i] for i in range(stacked_predictions.size()[1])]
loss = sum(self.loss(prediciton, t) for prediction in predictions)

optimizer.zero_grad()
loss.backward()
optimizer.step()

是否有更有效的方法来实现同样的目标?

是的,谢天谢地,您可以在 python 中通过广播轻松做到这一点: 假设:

N=100 #For example
k=10 #For example
stacked_predictions = torch.randn(N, k, 1)  # (N, k, 1) tensor with the predictions
t = torch.randn(N,1)  # common target

然后,您可以获得 loss:

的有效等效计算
loss = k * nn.MSELoss()(stacked_predictions, t[:, None, :])

(将 nn.MSELoss() 替换为 self.loss)。 请注意,t[:, None, :] 在中间向 t 添加了另一个单例维度,因此 t 变为 (N,1,1) 形状,而 stacked_predictions 变为 (N,k,1) 形状。当您在这 2 个张量上调用 nn.MSELoss() 时,t 的深处将被广播以匹配 stacked_predictions 的形状,这就是您无需自己重复 t 的原因。

请注意,我添加了 k 的乘法,这是因为您想要的是总和,而不是维度 1(大小为 k 的那个)的平均值。如果您想要所有预测的平均值,请省略乘法。