'torchmetrics' 不适用于 PyTorchLightning

'torchmetrics' does not work with PyTorchLightning

我正在尝试了解如何将 torchmetrics 与 PyTorch Lightning 一起使用。
但是,我得到了与准确度、F1 分数、精度等相同的输出

这是代码。

metric_acc = torchmetrics.Accuracy()
metric_f1 = torchmetrics.F1()
metric_pre = torchmetrics.Precision()
metric_rec = torchmetrics.Recall()

n_batches = 3
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))

    acc = metric_acc(preds, target)
    f1 = metric_f1(preds, target)
    pre = metric_pre(preds, target)
    rec = metric_rec(preds, target)
    print(f"Accuracy on batch {i}: {acc}")
    print(f"F1 score on batch {i}: {f1}")
    print(f"pre score on batch {i}: {pre}")
    print(f"rec score on batch {i}: {rec}")
    print('-' * 20)


acc = metric_acc.compute()
f1 = metric_f1.compute()
pre = metric_pre.compute()
rec = metric_rec.compute()
print(f"Accuracy on all data: {acc}")
print(f"f1 score on all data: {f1}")
print(f"pre score on all data: {pre}")
print(f"rec score on all data: {rec}")

结果来了。

Accuracy on batch 0: 0.10000000149011612
F1 score on batch 0: 0.10000000894069672
pre score on batch 0: 0.10000000149011612
rec score on batch 0: 0.10000000149011612
--------------------
Accuracy on batch 1: 0.30000001192092896
F1 score on batch 1: 0.30000001192092896
pre score on batch 1: 0.30000001192092896
rec score on batch 1: 0.30000001192092896
--------------------
Accuracy on batch 2: 0.4000000059604645
F1 score on batch 2: 0.40000003576278687
pre score on batch 2: 0.4000000059604645
rec score on batch 2: 0.4000000059604645
--------------------
Accuracy on all data: 0.2666666805744171
f1 score on all data: 0.2666666805744171
pre score on all data: 0.2666666805744171
rec score on all data: 0.2666666805744171

Process finished with exit code 0

当我将它与 PyTorchLightning 一起使用时,我得到了相同的结果,所以我用简单的代码尝试并得到了同样的结果。
如果您知道问题或解决方案,请告诉我。
非常感谢。

我猜您正在寻找有关如何使用 TorchMetrics 通过 PytorchLightning 记录训练进度的 simple example。否则,您能否更详细地说明您的用例,最好添加您的应用程序示例?

这样做的原因是,如果您使用 F1、Precision、ACC 和 Recall with micro(默认),对于 multi class classification,这些是 equivalent metrics 并推荐你应该使用宏

metric_acc = torchmetrics.Accuracy(average='macro')
metric_f1 = torchmetrics.F1(average='macro')
metric_pre = torchmetrics.Precision(average='macro')
metric_rec = torchmetrics.Recall(average='macro')