有什么方法可以在pytorch的损失函数中包含一个计数器(一个可以计数的变量)?
is there any way to include a counter(a variable that count something) in a loss function in pytorch?
这些是我的损失函数中的一些行。 output
是多类分类网络的输出。
bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)
我希望 dr_output.sum()
成为我的损失函数的一部分。但是我的实现有很多限制。有些函数在 pytorch 中是不可微分的,而且 dr_output
也可能为零,如果我只使用 dr_output
作为我的损失,这也是不允许的。任何人都可以向我建议解决这些问题的方法吗?
如果我没记错的话:
bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])
计算每行有多少元素大于 .1
。
改为:
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)
对应行只有大于.1
的元素为真,预测正确
dr_output.sum()
然后计算有多少行验证此条件,因此最小化损失可能会强制执行不正确的预测或分布值大于 .1
.
考虑到这些因素,您可以通过以下方式估算您的损失:
import torch.nn.functional as F
# x are the inputs, y the labels
mask = x > 0.1
p = F.softmax(x, dim=1)
out = p * (mask.sum(dim=1, keepdim=True) == 1)
loss = out[torch.arange(x.shape[0]), y].sum()
您可以设计更适合您的问题的类似变体。
这些是我的损失函数中的一些行。 output
是多类分类网络的输出。
bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)
我希望 dr_output.sum()
成为我的损失函数的一部分。但是我的实现有很多限制。有些函数在 pytorch 中是不可微分的,而且 dr_output
也可能为零,如果我只使用 dr_output
作为我的损失,这也是不允许的。任何人都可以向我建议解决这些问题的方法吗?
如果我没记错的话:
bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])
计算每行有多少元素大于 .1
。
改为:
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)
对应行只有大于.1
的元素为真,预测正确
dr_output.sum()
然后计算有多少行验证此条件,因此最小化损失可能会强制执行不正确的预测或分布值大于 .1
.
考虑到这些因素,您可以通过以下方式估算您的损失:
import torch.nn.functional as F
# x are the inputs, y the labels
mask = x > 0.1
p = F.softmax(x, dim=1)
out = p * (mask.sum(dim=1, keepdim=True) == 1)
loss = out[torch.arange(x.shape[0]), y].sum()
您可以设计更适合您的问题的类似变体。