RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed

I'm getting the error shown in the title. I found some answers suggesting retain_graph=True, so I tried that, but it doesn't work. Maybe there is some other problem in my code (the error occurs at loss_actor.backward(retain_graph=...)):

q = torch.zeros(len(reward))
q_target = torch.zeros(len(reward))
for j, r in enumerate(reward):
    q_target[j] = self.critic_network(torch.transpose(next_state[j], 0, 1), self.actor_network(torch.transpose(next_state[j], 0, 1)).view(1, 1))
    q_target[j] = r + (done[j] * gamma * q_target[j]).detach()
    q[j] = self.critic_network(torch.transpose(state[j], 0, 1), action[j].view(1, 1))
loss_critic = F.mse_loss(q, q_target)
self.critic_optimizer.zero_grad()
loss_critic.backward()
self.critic_optimizer.step()

b = torch.zeros(len(reward))
for j, r in enumerate(reward):
    b[j] = self.critic_network(torch.transpose(state[j], 0, 1), self.actor_network(torch.transpose(state[j], 0, 1)).view(1, 1))
loss_actor = -torch.mean(b)
self.actor_optimizer.zero_grad()
loss_actor.backward(retain_graph=True)
self.actor_optimizer.step()

From the information you provided about your partial computation graph, I assume loss_actor and loss_critic share some part of it; I believe that shared part is state (not certain):

state --> q --> loss_critic <-- backward 1
  |
  +-----> b --> loss_actor  <-- backward 2

Reproducing your example:

# Some computations that produce state
state = torch.ones((2, 2), requires_grad=True) ** 2

# Compute the first loss
q = torch.zeros((1))
q[0] = state[0, 0]
l1 = torch.sum(2 * q)
l1.backward()

# Compute the second loss
b = torch.zeros((1))
b[0] = state[1, 1]
l2 = torch.mean(2 * b)
l2.backward()
Running this raises:

RuntimeError                              Traceback (most recent call last)
<ipython-input-28-2ab509bedf7a> in <module>
     10 b[0] = state[1, 1]
     11 l2 = torch.mean(2 * b)
---> 12 l2.backward()
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

Trying

...
l2.backward(retain_graph=True)

does not help, because you must

Specify retain_graph=True when calling backward the first time.

Here, the first backward call is the one on l1, so it should be:

l1.backward(retain_graph=True)
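For completeness, here is the reproduction with the fix applied: the first backward retains the graph, so the second backward through the shared state succeeds. A minimal sketch mirroring the snippet above (the loss values follow from state being all ones):

```python
import torch

# Some computations that produce state (shared by both losses)
state = torch.ones((2, 2), requires_grad=True) ** 2

# First loss: retain the graph so its intermediate buffers
# survive for the second backward pass
q = torch.zeros((1))
q[0] = state[0, 0]
l1 = torch.sum(2 * q)
l1.backward(retain_graph=True)

# Second loss: backward through the shared part of the graph now succeeds
b = torch.zeros((1))
b[0] = state[1, 1]
l2 = torch.mean(2 * b)
l2.backward()  # no RuntimeError this time
```

In your code the first backward is loss_critic.backward(), so that is the call that would need retain_graph=True.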