RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed
I'm getting the error shown in the title. I found some answers, so I tried retain_graph=True, but it doesn't work. Maybe there is some other problem in my code (it happens at loss_actor.backward(retain_graph=...)):
q = torch.zeros(len(reward))
q_target = torch.zeros(len(reward))
for j, r in enumerate(reward):
    q_target[j] = self.critic_network(torch.transpose(next_state[j], 0, 1), self.actor_network(torch.transpose(next_state[j], 0, 1)).view(1, 1))
    q_target[j] = r + (done[j] * gamma * q_target[j]).detach()
    q[j] = self.critic_network(torch.transpose(state[j], 0, 1), action[j].view(1, 1))
loss_critic = F.mse_loss(q, q_target)
self.critic_optimizer.zero_grad()
loss_critic.backward()
self.critic_optimizer.step()

b = torch.zeros(len(reward))
for j, r in enumerate(reward):
    b[j] = self.critic_network(torch.transpose(state[j], 0, 1), self.actor_network(torch.transpose(state[j], 0, 1)).view(1, 1))
loss_actor = -torch.mean(b)
self.actor_optimizer.zero_grad()
loss_actor.backward(retain_graph=True)
self.actor_optimizer.step()
Based on the information you provided about part of your computational graph, I assume that loss_actor and loss_critic share some part of it; I think it is state (not sure). By default, backward() frees the intermediate buffers saved for the backward pass as it goes, so the second backward through the shared part finds them already freed:

state ---> q --> loss_critic   <-- backward 1
  |
  -------> b --> loss_actor    <-- backward 2
Reproducing your example:
import torch

# Some computations that produce state
state = torch.ones((2, 2), requires_grad=True) ** 2

# Compute the first loss
q = torch.zeros((1))
q[0] = state[0, 0]
l1 = torch.sum(2 * q)
l1.backward()

# Compute the second loss (shares state with the first)
b = torch.zeros((1))
b[0] = state[1, 1]
l2 = torch.mean(2 * b)
l2.backward()
RuntimeError Traceback (most recent call last)
<ipython-input-28-2ab509bedf7a> in <module>
10 b[0] = state[1, 1]
11 l2 = torch.mean(2 * b)
---> 12 l2.backward()
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
Trying

l2.backward(retain_graph=True)

does not help, because you must

Specify retain_graph=True when calling backward the first time.

Here, that is the first backward call, the one on l1:

l1.backward(retain_graph=True)
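
Putting it together, here is the corrected minimal example as a sketch (same toy graph as above):

import torch

# Some computations that produce state (shared by both losses)
state = torch.ones((2, 2), requires_grad=True) ** 2

# First loss: retain the graph so the shared part survives this backward
q = torch.zeros((1))
q[0] = state[0, 0]
l1 = torch.sum(2 * q)
l1.backward(retain_graph=True)

# Second loss: this backward now succeeds, and the graph can be freed
b = torch.zeros((1))
b[0] = state[1, 1]
l2 = torch.mean(2 * b)
l2.backward()

Applied to your snippet, that would mean calling loss_critic.backward(retain_graph=True) (your first backward) and leaving loss_actor.backward() without the flag.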