pos_weight在二元交叉熵计算中
pos_weight in binary cross entropy calculation
当我们处理不平衡的训练数据时(负样本多,正样本少),通常会用到pos_weight
参数。
pos_weight
的期望是,当 positive sample
得到错误的标签时,模型会比 negative sample
得到更高的损失。
当我使用binary_cross_entropy_with_logits
函数时,我发现:
bce = torch.nn.functional.binary_cross_entropy_with_logits
pos_weight = torch.FloatTensor([5])
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)
preds_neg_wrong = torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)
但是:
>>> loss_pos_wrong
tensor(2.0359)
>>> loss_neg_wrong
tensor(2.0359)
错误的正样本和负样本得到的损失是一样的,那么pos_weight
如何计算不平衡数据损失?
TLDR;两个损失是相同的,因为您正在计算相同的数量:两个输入是相同的,两个批处理元素和标签只是交换了。
为什么你会得到同样的损失?
我认为您对 F.binary_cross_entropy_with_logits
(you can find a more detailed documentation page with nn.BCEWithLogitsLoss
) 的用法感到困惑。在您的情况下,您的输入形状(aka 模型的输出)是一维的,这意味着您只有一个 logit x
, 而不是两个).
在你的例子中你有
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
这意味着您的批量大小为 2
,并且由于默认情况下该函数正在平均批量元素的损失,因此您最终会得到相同的结果 BCE(preds_pos_wrong, label_pos)
和 BCE(preds_neg_wrong, label_neg)
.您的批处理的两个元素刚刚切换。
您可以通过不使用 reduction='none'
选项平均批量元素的损失来非常容易地验证这一点:
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
pos_weight=pos_weight, reduction='none')
tensor([2.3704, 1.7014])
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
pos_weight=pos_weight, reduction='none')
tensor([1.7014, 2.3704])
调查F.binary_cross_entropy_with_logits
:
也就是说二元交叉熵的公式是:
bce = -[y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]
其中 y
(分别 sigmoid(x)
是与该 logit 关联的正数 class,而 1 - y
(分别是 1 - sigmoid(x)
)是负数class.
文档在 pos_weight
的加权方案上可能更精确(不要与 weight
混淆,后者是不同 logits 输出的加权)。如您所说,pos_weight
的想法是权衡正项,而不是整个项。
bce = -[w_p*y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]
其中 w_p
是正项的权重,以补偿正负样本的不平衡。实际上,这应该是 w_p = #positive/#negative
.
因此:
>>> w_p = torch.FloatTensor([5])
>>> preds = torch.FloatTensor([0.5, 1.5])
>>> label = torch.FloatTensor([1, 0])
使用内置损失函数,
>>> F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction='none')
tensor([2.3704, 1.7014])
与人工计算相比:
>>> z = torch.sigmoid(preds)
>>> -(w_p*label*torch.log(z) + (1-label)*torch.log(1-z))
tensor([2.3704, 1.7014])
当我们处理不平衡的训练数据时(负样本多,正样本少),通常会用到pos_weight
参数。
pos_weight
的期望是,当 positive sample
得到错误的标签时,模型会比 negative sample
得到更高的损失。
当我使用binary_cross_entropy_with_logits
函数时,我发现:
bce = torch.nn.functional.binary_cross_entropy_with_logits
pos_weight = torch.FloatTensor([5])
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)
preds_neg_wrong = torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)
但是:
>>> loss_pos_wrong
tensor(2.0359)
>>> loss_neg_wrong
tensor(2.0359)
错误的正样本和负样本得到的损失是一样的,那么pos_weight
如何计算不平衡数据损失?
TLDR;两个损失是相同的,因为您正在计算相同的数量:两个输入是相同的,两个批处理元素和标签只是交换了。
为什么你会得到同样的损失?
我认为您对 F.binary_cross_entropy_with_logits
(you can find a more detailed documentation page with nn.BCEWithLogitsLoss
) 的用法感到困惑。在您的情况下,您的输入形状(aka 模型的输出)是一维的,这意味着您只有一个 logit x
, 而不是两个).
在你的例子中你有
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
这意味着您的批量大小为 2
,并且由于默认情况下该函数正在平均批量元素的损失,因此您最终会得到相同的结果 BCE(preds_pos_wrong, label_pos)
和 BCE(preds_neg_wrong, label_neg)
.您的批处理的两个元素刚刚切换。
您可以通过不使用 reduction='none'
选项平均批量元素的损失来非常容易地验证这一点:
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
pos_weight=pos_weight, reduction='none')
tensor([2.3704, 1.7014])
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
pos_weight=pos_weight, reduction='none')
tensor([1.7014, 2.3704])
调查F.binary_cross_entropy_with_logits
:
也就是说二元交叉熵的公式是:
bce = -[y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]
其中 y
(分别 sigmoid(x)
是与该 logit 关联的正数 class,而 1 - y
(分别是 1 - sigmoid(x)
)是负数class.
文档在 pos_weight
的加权方案上可能更精确(不要与 weight
混淆,后者是不同 logits 输出的加权)。如您所说,pos_weight
的想法是权衡正项,而不是整个项。
bce = -[w_p*y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]
其中 w_p
是正项的权重,以补偿正负样本的不平衡。实际上,这应该是 w_p = #positive/#negative
.
因此:
>>> w_p = torch.FloatTensor([5])
>>> preds = torch.FloatTensor([0.5, 1.5])
>>> label = torch.FloatTensor([1, 0])
使用内置损失函数,
>>> F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction='none')
tensor([2.3704, 1.7014])
与人工计算相比:
>>> z = torch.sigmoid(preds)
>>> -(w_p*label*torch.log(z) + (1-label)*torch.log(1-z))
tensor([2.3704, 1.7014])