这个自定义 PyTorch 损失函数是否可微

Is this custom PyTorch loss function differentiable

我有一个 PyTorch 损失的自定义 forward 实现。培训效果很好。我已经检查了 loss.grad_fn,它不是 None。 我想了解两件事:

  1. 由于在从输入到输出的路径上存在 if-else 语句,因此该函数如何可微?

  2. gt(ground truth input)到loss(output)的路径是否需要可微?或者仅来自 pred(预测输入)的路径?

这里是源代码:

class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, gt):
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        neg_weights = torch.pow(1 - gt, 4)

        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        pos_loss_s = pos_loss.sum()
        neg_loss_s = neg_loss.sum()
        if num_pos == 0:
            loss = - neg_loss_s
        else:
            loss = - (pos_loss_s + neg_loss_s) / num_pos

        return loss

if 语句不是计算图的一部分。它是用于动态构建此图的代码的一部分(即 forward 函数),但它本身并不是其中的一部分。要遵循的原则是问问自己是否使用 grad_fn 回溯到图的叶子(图中没有父项的张量, 输入和参数)每个节点的回调,通过图形反向传播。答案是,只有当每个运算符都是可微的时,您才能这样做:在编程术语中,它们实现了向后函数运算(a.k.a. grad_fn)。

  1. 在您的示例中,num_pos 是否等于 0,由此产生的损失张量将单独取决于 neg_loss_spos_loss_sneg_loss_s。然而,在任何一种情况下,生成的 loss 张量仍然附加到输入 pred:

    • 通过一种方式:“neg_loss_s”节点
    • 其他:“pos_loss_s”和“neg_loss_s”节点。

在您的设置中,无论哪种方式,操作都是可微分的。

  1. 如果 gt 是一个 ground-truth 张量那么它不需要梯度并且从它到最终损失的操作不需要是可微的。在您的示例中就是这种情况,其中 pos_indsneg_inds 都是 non-differientblae 因为它们是布尔运算符。

PyTorch 计算梯度 w.r.t 损失函数本身。 PyTorch 记录在 forward 遍期间执行的标准数学运算序列,例如对数、求幂、乘法、加法等,并计算它们的梯度 w.r.t 这些 数学运算 当调用 backward() 时。因此,如果您仅使用标准数学运算来计算损失,if-else 条件的存在对 PyTorch 无关紧要。