pytorch 是否对其计算图进行急切修剪?

Does pytorch do eager pruning of its computational graph?

这是一个非常简单的例子:

import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

s = torch.sum(x * y * z)
s.backward()

print(x.grad)

这将打印,

tensor([2., 2., 0., 0., 0.]),

因为对于 z 为零的条目,ds/dx 当然是零。

我的问题是:pytorch 是否智能并在达到零时停止计算?或者实际上是在计算“2*5”,只是为了稍后做“10 * 0 = 0”?

在这个简单的例子中,它并没有太大的不同,但在我正在研究的(更大的)问题中,这会产生很大的不同。

感谢您的任何意见。

不,pytorch 不会在达到零时修剪任何后续计算。更糟糕的是,由于 float 算法的工作原理,所有后续的零乘法将花费与任何常规乘法大致相同的时间。

对于某些情况,虽然有一些解决方法,例如,如果你想使用掩码损失,你可以 将掩码输出设置 为零,或者将它们从梯度。

这个例子说明了区别:

def time_backward(do_detach):
    x = torch.tensor(torch.rand(100000000), requires_grad=True)
    y = torch.tensor(torch.rand(100000000), requires_grad=True)
    s2 = torch.sum(x * y)
    s1 = torch.sum(x * y)
    if do_detach:
        s2 = s2.detach()
    s = s1 + 0 * s2
    t = time.time()
    s.backward()
    print(time.time() - t)

time_backward(do_detach= False)
time_backward(do_detach= True)

输出:

0.502875089645
0.198422908783