我的模型前向部分的输入是一个元组,不能转换成onnx格式?

The input of the forward part of my model is a tuple, cannot be converted to onnx format?

测试代码:

    #!/usr/bin/env python
    # -*- coding:utf-8 -*-
    import torch
    import torch.nn as nn


    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = nn.Linear(32, 16)
            self.relu1 = nn.ReLU(inplace=True)
            self.relu2 = nn.ReLU(inplace=True)
            self.fc = nn.Linear(32, 2)

        def forward(self, x):
            x1, x2 = x
            x1 = self.linear(x1)
            x1 = self.relu1(x1)
            x2 = self.linear(x2)
            x2 = self.relu2(x2)
            out = torch.cat((x1, x2), dim=-1)
            out = self.fc(out)
            return out


    model = Model()
    model.eval()

    x1 = torch.randn((2, 10, 32))
    x2 = torch.randn((2, 10, 32))
    x = (x1, x2)

    torch.onnx.export(model,
                  x,
                  'model.onnx',
                  input_names=["input"],
                  output_names=["output"],
                  dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}}
                  )
    print("Done")

如何把上面的代码转换成onnx? 我的模型前向部分的输入是一个元组,不能转换成onnx格式? 谢谢! 我的模型前向部分的输入是一个元组,按照现有的方法是无法转换成onnx格式的。你能告诉我怎么解决吗

正在查看this issue and this other issue, the parameters are unpacked by default so you need to provide a tuple as argument to torch.onnx.export

torch.onnx.export(model,
   args=(x,),
   f='model.onnx',
   input_names=["input"],
   output_names=["output"],
   dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})