了解何时在 Pytorch 中使用 python 列表

Understanding when to use python list in Pytorch

基本上因为这个线程讨论 here,你不能使用 python 列表来包装你的子模块(例如你的层);否则,Pytorch 不会更新列表中子模块的参数。相反,您应该使用 nn.ModuleList 来包装您的子模块,以确保它们的参数将被更新。现在我也看到了像下面这样的代码,作者使用 python 列表来计算损失,然后使用 loss.backward() 来进行更新(在 RL 的强化算法中)。这是代码:

 policy_loss = []
    for log_prob in self.controller.log_probability_slected_action_list:
        policy_loss.append(- log_prob * (average_reward - b))
    self.optimizer.zero_grad()
    final_policy_loss = (torch.cat(policy_loss).sum()) * gamma
    final_policy_loss.backward()
    self.optimizer.step()

为什么使用这种格式的列表可以更新模块的参数,而第一种情况不起作用?我现在很困惑。如果我更改之前的代码 policy_loss = nn.ModuleList([]),它会抛出一个异常,说明张量浮点数不是子模块。

您误解了 Module 是什么。 Module 存储参数并定义前向传递的实现。

您可以使用张量和参数执行任意计算,从而产生其他新的张量。 Modules 不需要知道那些张量。您还可以将张量列表存储在 Python 列表中。当调用 backward 时,它需要在标量张量上,因此是串联的总和。这些张量是损失而不是参数,所以它们不应该是 Module 的属性,也不应该包含在 ModuleList.