使用来自 torchmetrics 的 F1Score 的不切实际的结果
Unrealistic results using F1Score from torchmetrics
我已经为 Pytorch Lightning 中的二元分类问题训练了一个分割 NN 模型。为了实现这一点,我使用了 BCEWithLogitsLoss。我的基本事实和预测的形状都是 (BZ, 640, 256) 它们的内容分别是 (0, 1) [0, 1].
现在,我正在尝试使用来自 torchmetrics 的 F1Score 在我的验证数据集上计算 F1 分数,然后使用 pytroch lightning 的 log_dict by
进行累积
from torchmetrics import F1Score
self.f1 = F1Score(num_classes=2)
我的验证步骤如下所示:
def validation_step(self, batch, batch_idx):
t0, t1, mask_gt = batch
mask_pred = self.forward(t0, t1)
mask_pred = torch.sigmoid(mask_pred).squeeze()
mask_pred = torch.where(mask_pred > 0.5, 1, 0)
f1_score_ = self.f1(mask_pred, mask_gt)
metrics = {
'val_f1_score': f1_score_,
}
self.log_dict(metrics, on_epoch=True)
这让我在每个 epoch 结束时得到了高得离谱的 F1 分数(即使在训练开始前的健全性验证检查中),~0.99,这让我觉得我没有将 F1Score 与 [=21= 一起使用] 正确的方式。我已经尝试了几个参数(https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/classification/f_beta.py#L181-L310)但没有成功。我做错了什么?
原来我有一个极度不平衡的数据集。我只对 class 1 的 f1 分数感兴趣,忽略了 class 0(gt 中的大多数情况)。我必须按以下方式配置 F1Score:
# ignore_index=0 --> excludes class 0 from f1 score calculation
F1Score(ignore_index=0, num_classes=2, average='macro', mdmc_average='samplewise')
有效步骤看起来一样
我已经为 Pytorch Lightning 中的二元分类问题训练了一个分割 NN 模型。为了实现这一点,我使用了 BCEWithLogitsLoss。我的基本事实和预测的形状都是 (BZ, 640, 256) 它们的内容分别是 (0, 1) [0, 1].
现在,我正在尝试使用来自 torchmetrics 的 F1Score 在我的验证数据集上计算 F1 分数,然后使用 pytroch lightning 的 log_dict by
进行累积from torchmetrics import F1Score
self.f1 = F1Score(num_classes=2)
我的验证步骤如下所示:
def validation_step(self, batch, batch_idx):
t0, t1, mask_gt = batch
mask_pred = self.forward(t0, t1)
mask_pred = torch.sigmoid(mask_pred).squeeze()
mask_pred = torch.where(mask_pred > 0.5, 1, 0)
f1_score_ = self.f1(mask_pred, mask_gt)
metrics = {
'val_f1_score': f1_score_,
}
self.log_dict(metrics, on_epoch=True)
这让我在每个 epoch 结束时得到了高得离谱的 F1 分数(即使在训练开始前的健全性验证检查中),~0.99,这让我觉得我没有将 F1Score 与 [=21= 一起使用] 正确的方式。我已经尝试了几个参数(https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/classification/f_beta.py#L181-L310)但没有成功。我做错了什么?
原来我有一个极度不平衡的数据集。我只对 class 1 的 f1 分数感兴趣,忽略了 class 0(gt 中的大多数情况)。我必须按以下方式配置 F1Score:
# ignore_index=0 --> excludes class 0 from f1 score calculation
F1Score(ignore_index=0, num_classes=2, average='macro', mdmc_average='samplewise')
有效步骤看起来一样