比较两个分割图预测

Compare two segmentation maps predictions

我在未标记数据上使用两个预测分割图之间的一致性。对于标记数据,我使用 nn.BCEwithLogitsLoss 和 Dice Loss.

我正在制作视频,这就是 5 维输出的原因。 (batch_size,通道,帧,高度,宽度)

我想知道我们如何比较两个预测的分割图。

gmentation maps.

# gt_seg - Ground truth segmentation map. - (8, 1, 8, 112, 112)
# aug_gt_seg - Augmented ground truth segmentation map - (8, 1, 8, 112, 112)

predicted_seg_1 = model(data, targets)       # (8, 1, 8, 112, 112)
predicted_seg_2 = model(augmented_data, augmented_targets) #(8, 1, 8, 112, 112)

# define criterion
seg_criterion_1 = nn.BCEwithLogitsLoss(size_average=True)
seg_criterion_2 = nn.DiceLoss()

# labeled losses
supervised_loss_1 = seg_criterion_1(predicted_seg_1, gt_seg)
supervised_loss_2 = seg_criterion_2(predicted_seg_1, gt_seg)

# Consistency loss
if consistency_loss == "l2":
      consistency_criterion = nn.MSELoss()
      cons_loss = consistency_criterion(predicted_gt_seg_1, predicted_gt_seg_2)

elif consistency_loss == "l1":
      consistency_criterion = nn.L1Loss()
      cons_loss = consistency_criterion(predicted_gt_seg_1, predicted_gt_seg_2)

total_supervised_loss = supervised_loss_1 + supervised_loss_2
total_consistency_loss = cons_loss

这是在两个预测分割图之间应用一致性的正确方法吗?

我主要是火炬网站上的定义搞糊涂了。这是输入 x 与目标 y 的比较。我认为它看起来是正确的,因为我希望两个预测的分割图相似。但是,第二个分割图不是目标。这就是为什么我很困惑。因为如果这可能是有效的,那么每个损失函数都可以以某种或另一种方式应用。这看起来对我没有吸引力。如果是正确的比较方式,能否推广到其他基于分割的损失如Dice Loss、IoU Loss等?

关于标记数据损失计算的另一个问题:

# gt_seg - Ground truth segmentation map
# aug_gt_seg - Augmented ground truth segmentation map

predicted_seg_1 = model(data, targets)
predicted_seg_2 = model(augmented_data, augmented_targets)

# define criterion
seg_criterion_1 = nn.BCEwithLogitsLoss(size_average=True)
seg_criterion_2 = nn.DiceLoss()

# labeled losses
supervised_loss_1 = seg_criterion_1(predicted_seg_1, gt_seg)
supervised_loss_2 = seg_criterion_2(predicted_seg_1, gt_seg)

# augmented labeled losses
aug_supervised_loss_1 = seg_criterion_1(predicted_seg_2, aug_gt_seg)
aug_supervised_loss_2 = seg_criterion_2(predicted_seg_2, aug_gt_seg)

total_supervised_loss = supervised_loss_1 + supervised_loss_2 + aug_supervised_loss_1 + aug_supervised_loss_2

total_supervised_loss的计算是否正确?我可以在此应用 loss.backward() 吗?

是的,这是实现一致性丢失的有效方法。 pytorch 文档使用的命名法列出了一个输入作为目标,另一个作为预测,但考虑到 L1、L2、Dice 和 IOU 损失都是对称的(即 Loss(a,b) = Loss(b,a)) .因此,这些函数中的任何一个都将实现一种形式的一致性损失,而不管一个输入实际上是基本事实还是“目标”。