RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed pytorch

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed pytorch

如何在第二次调用.backward()之前清除渐变。

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

a = torch.tensor([2.0], requires_grad = True)
b = torch.tensor([2.0], requires_grad = True)
d = torch.tensor([2.0], requires_grad = True)
c=a*b
c.backward()
e = d*e
e.backward(retain_graph=True)

我试过这样做:c.zero_grad() 但我收到错误 c 没有方法 zero_grad()

如错误消息所示,您需要在第一次 .backward 调用而不是第二次调用时指定 retain_graph=True 选项:

c.backward(retain_graph=True)
e = d*c
e.backward()

如果不保留图形,第二次反向传递将无法到达节点 cab,因为激活将被清除通过第一次向后传球。