PyTorch backward() 在受其他元素中的 nan 影响的张量元素上
PyTorch backward() on a tensor element affected by nan in other elements
考虑以下两个示例:
x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2, ), float("nan"))
z0 = x * y / y
z1 = x + y
print(z0, z1) # tensor(nan, grad_fn=<DivBackward0>) tensor(1., grad_fn=<AddBackward0>)
z1.backward()
print(x.grad) # tensor(1.)
x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2, ), float("nan"))
z[0] = x * y / y
z[1] = x + y
print(z) # tensor([nan, 1.], grad_fn=<CopySlices>)
z[1].backward()
print(x.grad) # tensor(nan)
例1中,z0
不影响z1
,z1
的backward()
按预期执行,x.grad
不是nan。但是在例子2中,z[1]
的backward()
好像是受到了z[0]
的影响,而x.grad
就是nan
.
如何防止这种情况(示例 1 是所需的行为)?具体来说,我需要在 z[0]
中保留 nan
,因此将 epsilon 添加到除法中没有帮助。
当 indexing 赋值中的张量时,PyTorch 访问张量的所有元素(它在引擎盖下使用二进制乘法掩码来保持可微性),这就是它获取 nan
的另一个元素(因为 0*nan -> nan
)。
我们可以在计算图中看到这一点:
torchviz.make_dot(z1, params={'x':x,'y':y})
torchviz.make_dot(z[1], params={'x':x,'y':y})
如果您希望避免这种行为,mask the nan
's,或者像您在第一个示例中所做的那样 - 将它们分成两个不同的对象。
考虑以下两个示例:
x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2, ), float("nan"))
z0 = x * y / y
z1 = x + y
print(z0, z1) # tensor(nan, grad_fn=<DivBackward0>) tensor(1., grad_fn=<AddBackward0>)
z1.backward()
print(x.grad) # tensor(1.)
x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2, ), float("nan"))
z[0] = x * y / y
z[1] = x + y
print(z) # tensor([nan, 1.], grad_fn=<CopySlices>)
z[1].backward()
print(x.grad) # tensor(nan)
例1中,z0
不影响z1
,z1
的backward()
按预期执行,x.grad
不是nan。但是在例子2中,z[1]
的backward()
好像是受到了z[0]
的影响,而x.grad
就是nan
.
如何防止这种情况(示例 1 是所需的行为)?具体来说,我需要在 z[0]
中保留 nan
,因此将 epsilon 添加到除法中没有帮助。
当 indexing 赋值中的张量时,PyTorch 访问张量的所有元素(它在引擎盖下使用二进制乘法掩码来保持可微性),这就是它获取 nan
的另一个元素(因为 0*nan -> nan
)。
我们可以在计算图中看到这一点:
torchviz.make_dot(z1, params={'x':x,'y':y}) |
torchviz.make_dot(z[1], params={'x':x,'y':y}) |
---|---|
如果您希望避免这种行为,mask the nan
's,或者像您在第一个示例中所做的那样 - 将它们分成两个不同的对象。