pytorch 是否对其计算图进行急切修剪?
Does pytorch do eager pruning of its computational graph?
这是一个非常简单的例子:
import torch
x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)
s = torch.sum(x * y * z)
s.backward()
print(x.grad)
这将打印,
tensor([2., 2., 0., 0., 0.]),
因为对于 z 为零的条目,ds/dx 当然是零。
我的问题是:pytorch 是否智能并在达到零时停止计算?或者实际上是在计算“2*5
”,只是为了稍后做“10 * 0 = 0
”?
在这个简单的例子中,它并没有太大的不同,但在我正在研究的(更大的)问题中,这会产生很大的不同。
感谢您的任何意见。
不,pytorch 不会在达到零时修剪任何后续计算。更糟糕的是,由于 float 算法的工作原理,所有后续的零乘法将花费与任何常规乘法大致相同的时间。
对于某些情况,虽然有一些解决方法,例如,如果你想使用掩码损失,你可以 将掩码输出设置 为零,或者将它们从梯度。
这个例子说明了区别:
def time_backward(do_detach):
x = torch.tensor(torch.rand(100000000), requires_grad=True)
y = torch.tensor(torch.rand(100000000), requires_grad=True)
s2 = torch.sum(x * y)
s1 = torch.sum(x * y)
if do_detach:
s2 = s2.detach()
s = s1 + 0 * s2
t = time.time()
s.backward()
print(time.time() - t)
time_backward(do_detach= False)
time_backward(do_detach= True)
输出:
0.502875089645
0.198422908783
这是一个非常简单的例子:
import torch
x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)
s = torch.sum(x * y * z)
s.backward()
print(x.grad)
这将打印,
tensor([2., 2., 0., 0., 0.]),
因为对于 z 为零的条目,ds/dx 当然是零。
我的问题是:pytorch 是否智能并在达到零时停止计算?或者实际上是在计算“2*5
”,只是为了稍后做“10 * 0 = 0
”?
在这个简单的例子中,它并没有太大的不同,但在我正在研究的(更大的)问题中,这会产生很大的不同。
感谢您的任何意见。
不,pytorch 不会在达到零时修剪任何后续计算。更糟糕的是,由于 float 算法的工作原理,所有后续的零乘法将花费与任何常规乘法大致相同的时间。
对于某些情况,虽然有一些解决方法,例如,如果你想使用掩码损失,你可以 将掩码输出设置 为零,或者将它们从梯度。
这个例子说明了区别:
def time_backward(do_detach):
x = torch.tensor(torch.rand(100000000), requires_grad=True)
y = torch.tensor(torch.rand(100000000), requires_grad=True)
s2 = torch.sum(x * y)
s1 = torch.sum(x * y)
if do_detach:
s2 = s2.detach()
s = s1 + 0 * s2
t = time.time()
s.backward()
print(time.time() - t)
time_backward(do_detach= False)
time_backward(do_detach= True)
输出:
0.502875089645
0.198422908783