标签不平衡的多标签分类
Multi label classification with unbalanced labels
我正在构建多标签分类网络。
我的 GT 是长度为 512
[0,0,0,1,0,1,0,...,0,0,0,1]
的向量
大多数时候它们是zeroes
,每个向量大约有5 ones
,其余的都是零。
我想做:
使用sigmoid
激活输出层。
损失函数使用binary_crossentropy
。
但是我该如何解决不平衡问题呢?
网络可以学习预测 always zeros
并且学习损失分数仍然很低。
我怎样才能让它真正学会预测...
您不能轻易地进行上采样,因为这是一个多标签案例(我最初从 post 中遗漏的内容)。
你可以做的是给1
更高的权重,像这样:
import torch
class BCEWithLogitsLossWeighted(torch.nn.Module):
def __init__(self, weight, *args, **kwargs):
super().__init__()
# Notice none reduction
self.bce = torch.nn.BCEWithLogitsLoss(*args, **kwargs, reduction="none")
self.weight = weight
def forward(self, logits, labels):
loss = self.bce(logits, labels)
binary_labels = labels.bool()
loss[binary_labels] *= labels[binary_labels] * self.weight
# Or any other reduction
return torch.mean(loss)
loss = BCEWithLogitsLossWeighted(50)
logits = torch.randn(64, 512)
labels = torch.randint(0, 2, size=(64, 512)).float()
print(loss(logits, labels))
您也可以使用 FocalLoss 关注正面示例(某些库中应该有一些实现)。
编辑:
Focal Loss 也可以按照这些思路进行编码(函数形式,因为这就是我在 repo 中的内容,但你应该能够从中开始工作):
def binary_focal_loss(
outputs: torch.Tensor,
targets: torch.Tensor,
gamma: float,
weight=None,
pos_weight=None,
reduction: typing.Callable[[torch.Tensor], torch.Tensor] = None,
) -> torch.Tensor:
probabilities = (1 - torch.sigmoid(outputs)) ** gamma
loss = probabilities * torch.nn.functional.binary_cross_entropy_with_logits(
outputs,
targets.float(),
weight,
reduction="none",
pos_weight=pos_weight,
)
return reduction(loss)
我正在构建多标签分类网络。
我的 GT 是长度为 512
[0,0,0,1,0,1,0,...,0,0,0,1]
的向量
大多数时候它们是zeroes
,每个向量大约有5 ones
,其余的都是零。
我想做:
使用sigmoid
激活输出层。
损失函数使用binary_crossentropy
。
但是我该如何解决不平衡问题呢?
网络可以学习预测 always zeros
并且学习损失分数仍然很低。
我怎样才能让它真正学会预测...
您不能轻易地进行上采样,因为这是一个多标签案例(我最初从 post 中遗漏的内容)。
你可以做的是给1
更高的权重,像这样:
import torch
class BCEWithLogitsLossWeighted(torch.nn.Module):
def __init__(self, weight, *args, **kwargs):
super().__init__()
# Notice none reduction
self.bce = torch.nn.BCEWithLogitsLoss(*args, **kwargs, reduction="none")
self.weight = weight
def forward(self, logits, labels):
loss = self.bce(logits, labels)
binary_labels = labels.bool()
loss[binary_labels] *= labels[binary_labels] * self.weight
# Or any other reduction
return torch.mean(loss)
loss = BCEWithLogitsLossWeighted(50)
logits = torch.randn(64, 512)
labels = torch.randint(0, 2, size=(64, 512)).float()
print(loss(logits, labels))
您也可以使用 FocalLoss 关注正面示例(某些库中应该有一些实现)。
编辑:
Focal Loss 也可以按照这些思路进行编码(函数形式,因为这就是我在 repo 中的内容,但你应该能够从中开始工作):
def binary_focal_loss(
outputs: torch.Tensor,
targets: torch.Tensor,
gamma: float,
weight=None,
pos_weight=None,
reduction: typing.Callable[[torch.Tensor], torch.Tensor] = None,
) -> torch.Tensor:
probabilities = (1 - torch.sigmoid(outputs)) ** gamma
loss = probabilities * torch.nn.functional.binary_cross_entropy_with_logits(
outputs,
targets.float(),
weight,
reduction="none",
pos_weight=pos_weight,
)
return reduction(loss)