如何在 Flux.jl 中使用 .pth 模型?
How to use a .pth model in Flux.jl?
我有一个用 PyTorch 训练的模型,保存为 .pth 格式。是否可以在 Flux.jl 中使用和加载该模型?我环顾四周,但在 Flux 文档的任何地方都没有看到这一点。
我能想到的唯一办法就是
- 通过
将.pth
转换为.onnx
import torch.onnx
import torchvision
import torch
dummy_input = #...
model = #...
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
torch.onnx.export(model, dummy_input, "model.onnx")
- 使用 ONNX.jl 从
.onnx
加载模型。该库目前似乎正在重建中,但旧的 API 可能适合您。仔细检查了一下,貌似加载模型后可能有偏差
另外,这个讨论是相关的:https://github.com/FluxML/ML-Coordination-Tracker/issues/10
我有一个用 PyTorch 训练的模型,保存为 .pth 格式。是否可以在 Flux.jl 中使用和加载该模型?我环顾四周,但在 Flux 文档的任何地方都没有看到这一点。
我能想到的唯一办法就是
- 通过 将
.pth
转换为.onnx
import torch.onnx
import torchvision
import torch
dummy_input = #...
model = #...
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
torch.onnx.export(model, dummy_input, "model.onnx")
- 使用 ONNX.jl 从
.onnx
加载模型。该库目前似乎正在重建中,但旧的 API 可能适合您。仔细检查了一下,貌似加载模型后可能有偏差
另外,这个讨论是相关的:https://github.com/FluxML/ML-Coordination-Tracker/issues/10