有什么方法可以在pytorch的损失函数中包含一个计数器(一个可以计数的变量)?

is there any way to include a counter(a variable that count something) in a loss function in pytorch?

这些是我的损失函数中的一些行。 output 是多类分类网络的输出。

bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])

dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)

我希望 dr_output.sum() 成为我的损失函数的一部分。但是我的实现有很多限制。有些函数在 pytorch 中是不可微分的,而且 dr_output 也可能为零,如果我只使用 dr_output 作为我的损失,这也是不允许的。任何人都可以向我建议解决这些问题的方法吗?

如果我没记错的话:

bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])

计算每行有多少元素大于 .1

改为:

dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)

对应行只有大于.1的元素为真,预测正确

dr_output.sum() 然后计算有多少行验证此条件,因此最小化损失可能会强制执行不正确的预测或分布值大于 .1.

考虑到这些因素,您可以通过以下方式估算您的损失:

import torch.nn.functional as F

# x are the inputs, y the labels

mask = x > 0.1
p = F.softmax(x, dim=1)
out = p * (mask.sum(dim=1, keepdim=True) == 1)

loss = out[torch.arange(x.shape[0]), y].sum()

您可以设计更适合您的问题的类似变体。