如何在忽略 class 的情况下使用 pytorch 闪电精度?

How to use pytorch lightning Accuracy with ignore class?

我有一些训练管道使用 CrossEntropyLoss 忽略 class。

模型输出 log_probs 形状 (150, 3) - 这意味着 3 种可能的 class 以 150 为一批。

label_batch 的形状是 150torch.max(label_batch) == tensor(3, device='cuda:0'),这意味着有一个额外的 class 标记为 3 , 即忽略 class.

损失处理得很好:

self._criterion = nn.CrossEntropyLoss(
    reduction='mean',
    ignore_index=3
)

但是准确度指标认为 class 3 是有效的并且给出了非常错误的结果:

self.train_acc = pl.metrics.Accuracy()

错误的结果 self.train_acc.update(log_probs, label_batch) 因为 3 标签应该被忽略。


如何正确使用 pl.metrics.Accuracy() 忽略 class?

正在复制 github 论坛中讨论帖的回复 https://github.com/PyTorchLightning/pytorch-lightning/discussions/6890


它目前在准确性指标中不受支持,但我们有一个开放的 PR 来实现这个确切的功能 PyTorchLightning/metrics#155

目前你能做的是计算混淆矩阵,然后忽略一些基于此的 类(记住真正的 positive/correctly 分类是在混淆矩阵的对角线上找到的):

ignore_index = 3
metric = ConfusionMatrix(num_classes=3)
confmat = metric(preds, target)
confmat = confmat[:2,:2] # remove last column and row corresponding to class 3
acc = confmat.trace() / confmat.sum()