保存 PyTorch 模型以转换为 ONNX

Save PyTorch model for conversion to ONNX

我是 Pytorch 的新手(和 Python),我已经按照本指南训练了一个模型,然后将权重保存到 pth 文件中: https://medium.com/@alexppppp/how-to-create-synthetic-dataset-for-computer-vision-keypoint-detection-78ba481cdafd

我的理解是,要将模型转换为 ONNX,您需要保存整个模型,而不仅仅是权重。

我想相关的代码是这样的:

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=1000)
    lr_scheduler.step()
    evaluate(model, data_loader_test, device)
    
# Save model weights after training
torch.save(model.state_dict(), 'keypointsrcnn_weights.pth')

是否有一种简单的方法来保存“整个”模型而不仅仅是权重?我在文档中看到过这个,但这看起来需要在纪元循环内而不是在训练之后?

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

请原谅我完全不理解。我的目的是尝试将 PyTorch 模型转换为 ONNX。

使用torch.onnx.export。应该看起来像

  arch = models.alexnet();      pic_x = 227
  dummy_input = torch.zeros((1,3, pic_x, pic_x))
  torch.onnx.export(arch, dummy_input, "alexnet.onnx", verbose=True, export_params=True, )

graph(%input.1 : Float(1, 3, 227, 227, strides=[154587, 51529, 227, 1], requires_grad=0, device=cpu),
      %features.0.weight : Float(64, 3, 11, 11, strides=[363, 121, 11, 1], requires_grad=1, device=cpu),
      %features.0.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
      %features.3.weight : Float(192, 64, 5, 5, strides=[1600, 25, 5, 1], requires_grad=1, device=cpu),
      %features.3.bias : Float(192, strides=[1], requires_grad=1, device=cpu),
      ...
      %classifier.6.weight : Float(1000, 4096, strides=[4096, 1], requires_grad=1, device=cpu),
      %classifier.6.bias : Float(1000, strides=[1], requires_grad=1, device=cpu)):
  %17 : Float(1, 64, 56, 56, strides=[200704, 3136, 56, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%input.1, %features.0.weight, %features.0.bias) # c:\python39\lib\site-packages\torch\nn\modules\conv.py:442:0
  %18 : Float(1, 64, 56, 56, strides=[200704, 3136, 56, 1], requires_grad=1, device=cpu) = onnx::Relu(%17) # c:\python39\lib\site-packages\torch\nn\functional.py:1297:0
  ...