如何在 Flux.jl 中使用 .pth 模型?

How to use a .pth model in Flux.jl?

我有一个用 PyTorch 训练的模型,保存为 .pth 格式。是否可以在 Flux.jl 中使用和加载该模型?我环顾四周,但在 Flux 文档的任何地方都没有看到这一点。

我能想到的唯一办法就是

  1. 通过
  2. .pth转换为.onnx
import torch.onnx
import torchvision
import torch

dummy_input = #...
model = #...
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
torch.onnx.export(model, dummy_input, "model.onnx")
  1. 使用 ONNX.jl.onnx 加载模型。该库目前似乎正在重建中,但旧的 API 可能适合您。仔细检查了一下,貌似加载模型后可能有偏差

另外,这个讨论是相关的:https://github.com/FluxML/ML-Coordination-Tracker/issues/10