如何在忽略 class 的情况下使用 pytorch 闪电精度?
How to use pytorch lightning Accuracy with ignore class?
我有一些训练管道使用 CrossEntropyLoss
忽略 class。
模型输出 log_probs
形状 (150, 3)
- 这意味着 3 种可能的 class 以 150 为一批。
label_batch
的形状是 150
,torch.max(label_batch)
== tensor(3, device='cuda:0')
,这意味着有一个额外的 class 标记为 3
, 即忽略 class.
损失处理得很好:
self._criterion = nn.CrossEntropyLoss(
reduction='mean',
ignore_index=3
)
但是准确度指标认为 class 3
是有效的并且给出了非常错误的结果:
self.train_acc = pl.metrics.Accuracy()
错误的结果 self.train_acc.update(log_probs, label_batch)
因为 3
标签应该被忽略。
如何正确使用 pl.metrics.Accuracy()
忽略 class?
正在复制 github 论坛中讨论帖的回复 https://github.com/PyTorchLightning/pytorch-lightning/discussions/6890
它目前在准确性指标中不受支持,但我们有一个开放的 PR 来实现这个确切的功能 PyTorchLightning/metrics#155
目前你能做的是计算混淆矩阵,然后忽略一些基于此的 类(记住真正的 positive/correctly 分类是在混淆矩阵的对角线上找到的):
ignore_index = 3
metric = ConfusionMatrix(num_classes=3)
confmat = metric(preds, target)
confmat = confmat[:2,:2] # remove last column and row corresponding to class 3
acc = confmat.trace() / confmat.sum()
我有一些训练管道使用 CrossEntropyLoss
忽略 class。
模型输出 log_probs
形状 (150, 3)
- 这意味着 3 种可能的 class 以 150 为一批。
label_batch
的形状是 150
,torch.max(label_batch)
== tensor(3, device='cuda:0')
,这意味着有一个额外的 class 标记为 3
, 即忽略 class.
损失处理得很好:
self._criterion = nn.CrossEntropyLoss(
reduction='mean',
ignore_index=3
)
但是准确度指标认为 class 3
是有效的并且给出了非常错误的结果:
self.train_acc = pl.metrics.Accuracy()
错误的结果 self.train_acc.update(log_probs, label_batch)
因为 3
标签应该被忽略。
如何正确使用 pl.metrics.Accuracy()
忽略 class?
正在复制 github 论坛中讨论帖的回复 https://github.com/PyTorchLightning/pytorch-lightning/discussions/6890
它目前在准确性指标中不受支持,但我们有一个开放的 PR 来实现这个确切的功能 PyTorchLightning/metrics#155
目前你能做的是计算混淆矩阵,然后忽略一些基于此的 类(记住真正的 positive/correctly 分类是在混淆矩阵的对角线上找到的):
ignore_index = 3
metric = ConfusionMatrix(num_classes=3)
confmat = metric(preds, target)
confmat = confmat[:2,:2] # remove last column and row corresponding to class 3
acc = confmat.trace() / confmat.sum()