尝试将 PyTorch 模型导出到 ONNX 时崩溃:forward() 缺少 1 个必需的位置参数

Crash when trying to export PyTorch model to ONNX: forward() missing 1 required positional argument

我正在尝试将 pyTorch 模型转换为 onnx,如下所示:

torch.onnx.export(
  model=modnet.module,
  args=example_input, 
  f=ONNX_PATH, # where should it be saved
  verbose=False,
  export_params=True,
  do_constant_folding=False,
  input_names=['input'],
  output_names=['output']
)

modnet 是来自此 repo 的模型:https://github.com/ZHKKKe/MODNet

example_input 是形状为 [1, 3, 512, 512]

的张量

在转换过程中我收到了错误:

TypeError: forward() missing 1 required positional argument: 'inference'

这是我克隆的 Colab 笔记本,用于重现异常:https://colab.research.google.com/drive/1AE1VAXIXkm26krIOoBaFfhoE53hhuEdf?usp=sharing

请救救我! :)

Modnet 转发方法需要一个名为 inference 的参数,它是一个布尔值,实际上在训练模型时他们以这种方式传递它:

# forward the main model
pred_semantic, pred_detail, pred_matte = modnet(image, False)

所以在这里你要做的就是像这样修改你的example_input

example_input = (example_input, True)