PyTorch backward() 在受其他元素中的 nan 影响的张量元素上

PyTorch backward() on a tensor element affected by nan in other elements

考虑以下两个示例:

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2, ), float("nan"))
z0 = x * y / y
z1 = x + y
print(z0, z1) # tensor(nan, grad_fn=<DivBackward0>) tensor(1., grad_fn=<AddBackward0>)
z1.backward()
print(x.grad) # tensor(1.)


x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2, ), float("nan"))
z[0] = x * y / y
z[1] = x + y
print(z) # tensor([nan, 1.], grad_fn=<CopySlices>)
z[1].backward()
print(x.grad) # tensor(nan)

例1中,z0不影响z1z1backward()按预期执行,x.grad不是nan。但是在例子2中,z[1]backward()好像是受到了z[0]的影响,而x.grad就是nan.

如何防止这种情况(示例 1 是所需的行为)?具体来说,我需要在 z[0] 中保留 nan,因此将 epsilon 添加到除法中没有帮助。

indexing 赋值中的张量时,PyTorch 访问张量的所有元素(它在引擎盖下使用二进制乘法掩码来保持可微性),这就是它获取 nan 的另一个元素(因为 0*nan -> nan)。

我们可以在计算图中看到这一点:

torchviz.make_dot(z1, params={'x':x,'y':y}) torchviz.make_dot(z[1], params={'x':x,'y':y})

如果您希望避免这种行为,mask the nan's,或者像您在第一个示例中所做的那样 - 将它们分成两个不同的对象。