将火炬模型导出为 onnx 格式时出现问题
issue while exporting torch model to onnx format
我正在尝试将我的 PyTorch 模型导出为 ONNX 格式,但我一直收到此错误:
TypeError: forward() missing 1 required positional argument: 'text'
这是我的代码:
model = Model(opt)
dummy_input = torch.randn(1, 3, 224, 224)
file_path='/content/drive/MyDrive/VitSTR/vitstr_tiny_patch16_224_aug.pth'
torch.save(model.state_dict(), file_path)
model.load_state_dict(torch.load(file_path))
#model = torch.nn.DataParallel(model).to(device)
#print(model)
torch.onnx.export(model, dummy_input, "vitstr.onnx", verbose=True)
ViTSTR forward 需要两个位置参数,input
和 text
:
def forward(self, input, text, is_train=True, seqlen=25):
# ...
因此,您需要传递一个额外的参数:
# ...
dummy_text = # create a dummy_text as well, with the appropriate shape
torch.onnx.export(model, (dummy_input, dummy_text), "vitstr.onnx", verbose=True)
我正在尝试将我的 PyTorch 模型导出为 ONNX 格式,但我一直收到此错误:
TypeError: forward() missing 1 required positional argument: 'text'
这是我的代码:
model = Model(opt)
dummy_input = torch.randn(1, 3, 224, 224)
file_path='/content/drive/MyDrive/VitSTR/vitstr_tiny_patch16_224_aug.pth'
torch.save(model.state_dict(), file_path)
model.load_state_dict(torch.load(file_path))
#model = torch.nn.DataParallel(model).to(device)
#print(model)
torch.onnx.export(model, dummy_input, "vitstr.onnx", verbose=True)
ViTSTR forward 需要两个位置参数,input
和 text
:
def forward(self, input, text, is_train=True, seqlen=25):
# ...
因此,您需要传递一个额外的参数:
# ...
dummy_text = # create a dummy_text as well, with the appropriate shape
torch.onnx.export(model, (dummy_input, dummy_text), "vitstr.onnx", verbose=True)