如果你有多个神经网络,PyTorch 如何知道训练损失应该传播回哪个神经网络?

How does PyTorch know to which neural network the training loss shall be propagated back if you have multiple neural networks?

我想借助另外两个已经过训练和测试的神经网络来训练一个神经网络。我要训练的网络的输入同时输入到第一个静态网络。我要训练的网络的输出被输入到第二个静态网络。损失应在静态网络的输出上计算并传播回列车网络。

# Initialization
var_model_statemapper = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 8)])

var_model_panda = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 27)])
var_model_panda.load_state_dict(torch.load("panda.pth"))

var_model_ur5 = NeuralNetwork(8, [('linear', 8), ('relu', None), ('dropout', 0.2), ('linear', 24)])
var_model_ur5.load_state_dict(torch.load("ur5.pth"))

var_loss_function = torch.nn.MSELoss()
var_optimizer = torch.optim.Adam(var_model_statemapper.parameters(), lr=0.001)

# Forward Propagation
var_panda_output = var_model_panda(var_statemapper_input)
var_ur5_output = var_model_ur5(var_statemapper_output)
var_train_loss = var_loss_function(var_panda_output, var_ur5_output)

# Backward Propagation
var_optimizer.zero_grad()
var_train_loss.backward()
var_optimizer.step()

可以看到“var_model_statemapper”是要训练的网络。网络“var_model_panda”和“var_model_ur5”已初始化,它们的 state_dicts 正在从相应的“.pth”文件中读取,因此这些网络需要是静态的。我的主要问题是,哪个网络在反向传播中被更新?只是“var_model_statemapper”还是所有网络?如果“var_model_statemapper”没有更新,我该如何实现? PyTorch 是否仅通过优化器的初始化就知道要更新哪个网络?

正式化您的管道以更好地了解设置:

x --- | state_mapper | --> y --- | ur5 | --> ur5_out
 \                                              |
  \                                             ↓
   \--- | panda | --> panda_out ----------- | loss_fn | --> loss

这是您提供的行发生的情况:

var_optimizer.zero_grad()  # 0.
var_train_loss.backward()  # 1.
var_optimizer.step()       # 2.
  1. 在优化器上调用 zero_grad 将清除该优化器中包含的所有参数梯度的缓存。在您的情况下,您已经 var_optimizer 注册了来自 var_model_statemapper 的参数(您要优化的模型)。

  2. 当您推断损失并通过 backward 调用对其进行反向传播时,梯度将通过所有三个模型的参数传播。

  3. 然后在优化器上调用 step 将更新在您调用它的优化器中注册的参数。在您的情况下,这意味着 var_optimizer.step() 将更新模型的所有参数 var_model_statemapper 单独 使用在步骤 1.[=51] 中计算的梯度=](即在 var_train_loss 上使用 backward 调用)。

总而言之,你目前的做法只会更新var_model_statemapper的参数。理想情况下,您可以通过将参数的 requires_grad 标志设置为 False 来冻结模型 var_model_pandavar_model_ur5。这将节省推理和训练的速度,因为在反向传播期间不会计算和存储它们的梯度。