我的模型前向部分的输入是一个元组,不能转换成onnx格式?
The input of the forward part of my model is a tuple, cannot be converted to onnx format?
测试代码:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = nn.Linear(32, 16)
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
self.fc = nn.Linear(32, 2)
def forward(self, x):
x1, x2 = x
x1 = self.linear(x1)
x1 = self.relu1(x1)
x2 = self.linear(x2)
x2 = self.relu2(x2)
out = torch.cat((x1, x2), dim=-1)
out = self.fc(out)
return out
model = Model()
model.eval()
x1 = torch.randn((2, 10, 32))
x2 = torch.randn((2, 10, 32))
x = (x1, x2)
torch.onnx.export(model,
x,
'model.onnx',
input_names=["input"],
output_names=["output"],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}}
)
print("Done")
如何把上面的代码转换成onnx?
我的模型前向部分的输入是一个元组,不能转换成onnx格式?
谢谢!
我的模型前向部分的输入是一个元组,按照现有的方法是无法转换成onnx格式的。你能告诉我怎么解决吗
正在查看this issue and this other issue, the parameters are unpacked by default so you need to provide a tuple as argument to torch.onnx.export
:
torch.onnx.export(model,
args=(x,),
f='model.onnx',
input_names=["input"],
output_names=["output"],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
测试代码:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = nn.Linear(32, 16)
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
self.fc = nn.Linear(32, 2)
def forward(self, x):
x1, x2 = x
x1 = self.linear(x1)
x1 = self.relu1(x1)
x2 = self.linear(x2)
x2 = self.relu2(x2)
out = torch.cat((x1, x2), dim=-1)
out = self.fc(out)
return out
model = Model()
model.eval()
x1 = torch.randn((2, 10, 32))
x2 = torch.randn((2, 10, 32))
x = (x1, x2)
torch.onnx.export(model,
x,
'model.onnx',
input_names=["input"],
output_names=["output"],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}}
)
print("Done")
如何把上面的代码转换成onnx? 我的模型前向部分的输入是一个元组,不能转换成onnx格式? 谢谢! 我的模型前向部分的输入是一个元组,按照现有的方法是无法转换成onnx格式的。你能告诉我怎么解决吗
正在查看this issue and this other issue, the parameters are unpacked by default so you need to provide a tuple as argument to torch.onnx.export
:
torch.onnx.export(model,
args=(x,),
f='model.onnx',
input_names=["input"],
output_names=["output"],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})