这个自定义 PyTorch 损失函数是否可微
Is this custom PyTorch loss function differentiable
我有一个 PyTorch 损失的自定义 forward
实现。培训效果很好。我已经检查了 loss.grad_fn
,它不是 None
。
我想了解两件事:
由于在从输入到输出的路径上存在 if
-else
语句,因此该函数如何可微?
从gt
(ground truth input)到loss(output)的路径是否需要可微?或者仅来自 pred
(预测输入)的路径?
这里是源代码:
class FocalLoss(nn.Module):
def __init__(self):
super(FocalLoss, self).__init__()
def forward(self, pred, gt):
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
neg_weights = torch.pow(1 - gt, 4)
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
num_pos = pos_inds.float().sum()
pos_loss_s = pos_loss.sum()
neg_loss_s = neg_loss.sum()
if num_pos == 0:
loss = - neg_loss_s
else:
loss = - (pos_loss_s + neg_loss_s) / num_pos
return loss
if
语句不是计算图的一部分。它是用于动态构建此图的代码的一部分(即 forward
函数),但它本身并不是其中的一部分。要遵循的原则是问问自己是否使用 grad_fn
回溯到图的叶子(图中没有父项的张量,即 输入和参数)每个节点的回调,通过图形反向传播。答案是,只有当每个运算符都是可微的时,您才能这样做:在编程术语中,它们实现了向后函数运算(a.k.a. grad_fn
)。
在您的示例中,num_pos
是否等于 0
,由此产生的损失张量将单独取决于 neg_loss_s
或 pos_loss_s
和 neg_loss_s
。然而,在任何一种情况下,生成的 loss
张量仍然附加到输入 pred
:
- 通过一种方式:“
neg_loss_s
”节点
- 或其他:“
pos_loss_s
”和“neg_loss_s
”节点。
在您的设置中,无论哪种方式,操作都是可微分的。
- 如果
gt
是一个 ground-truth 张量那么它不需要梯度并且从它到最终损失的操作不需要是可微的。在您的示例中就是这种情况,其中 pos_inds
和 neg_inds
都是 non-differientblae 因为它们是布尔运算符。
PyTorch 不 计算梯度 w.r.t 损失函数本身。 PyTorch 记录在 forward
遍期间执行的标准数学运算序列,例如对数、求幂、乘法、加法等,并计算它们的梯度 w.r.t 这些 数学运算 当调用 backward()
时。因此,如果您仅使用标准数学运算来计算损失,if-else
条件的存在对 PyTorch 无关紧要。
我有一个 PyTorch 损失的自定义 forward
实现。培训效果很好。我已经检查了 loss.grad_fn
,它不是 None
。
我想了解两件事:
由于在从输入到输出的路径上存在
if
-else
语句,因此该函数如何可微?从
gt
(ground truth input)到loss(output)的路径是否需要可微?或者仅来自pred
(预测输入)的路径?
这里是源代码:
class FocalLoss(nn.Module):
def __init__(self):
super(FocalLoss, self).__init__()
def forward(self, pred, gt):
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
neg_weights = torch.pow(1 - gt, 4)
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
num_pos = pos_inds.float().sum()
pos_loss_s = pos_loss.sum()
neg_loss_s = neg_loss.sum()
if num_pos == 0:
loss = - neg_loss_s
else:
loss = - (pos_loss_s + neg_loss_s) / num_pos
return loss
if
语句不是计算图的一部分。它是用于动态构建此图的代码的一部分(即 forward
函数),但它本身并不是其中的一部分。要遵循的原则是问问自己是否使用 grad_fn
回溯到图的叶子(图中没有父项的张量,即 输入和参数)每个节点的回调,通过图形反向传播。答案是,只有当每个运算符都是可微的时,您才能这样做:在编程术语中,它们实现了向后函数运算(a.k.a. grad_fn
)。
在您的示例中,
num_pos
是否等于0
,由此产生的损失张量将单独取决于neg_loss_s
或pos_loss_s
和neg_loss_s
。然而,在任何一种情况下,生成的loss
张量仍然附加到输入pred
:- 通过一种方式:“
neg_loss_s
”节点 - 或其他:“
pos_loss_s
”和“neg_loss_s
”节点。
- 通过一种方式:“
在您的设置中,无论哪种方式,操作都是可微分的。
- 如果
gt
是一个 ground-truth 张量那么它不需要梯度并且从它到最终损失的操作不需要是可微的。在您的示例中就是这种情况,其中pos_inds
和neg_inds
都是 non-differientblae 因为它们是布尔运算符。
PyTorch 不 计算梯度 w.r.t 损失函数本身。 PyTorch 记录在 forward
遍期间执行的标准数学运算序列,例如对数、求幂、乘法、加法等,并计算它们的梯度 w.r.t 这些 数学运算 当调用 backward()
时。因此,如果您仅使用标准数学运算来计算损失,if-else
条件的存在对 PyTorch 无关紧要。