Torchscript incompatible with torch.cat for tensor lists
torch.cat throws an error on a list of tensors when used in TorchScript.
Here is a minimal reproducible example of the error:
```python
import torch
import torch.nn as nn

"""
Smallest working bug for torch.cat torchscript
"""


class Model(nn.Module):
    """dummy model for showing error"""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self):
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        out = torch.cat([a, b], axis=2)
        return out


if __name__ == '__main__':
    model = Model()
    print(model())  # works
    torch.jit.script(model)  # throws error
```
The expected result is the TorchScript output of torch.cat. Instead, scripting fails with this error message:
```
  File "/home/anil/.conda/envs/rnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1423, in _create_methods_from_stubs
    self._c._create_methods(self, defs, rcbs, defaults)
RuntimeError:
Arguments for call are not valid.
The following operator variants are available:
  aten::cat(Tensor[] tensors, int dim=0) -> (Tensor):
  Keyword argument axis unknown.
  aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> (Tensor(a!)):
  Argument out not provided.
The original call is:
at smallest_working_bug_torch_cat_torchscript.py:19:14
    def forward(self):
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        out = torch.cat([a, b], axis=2)
              ~~~~~~~~~ <--- HERE
        return out
```
Please let me know of a fix or workaround for this issue.
Thanks!
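Changing `axis` to `dim` fixed the error. As the error message shows, the TorchScript schema `aten::cat(Tensor[] tensors, int dim=0)` only knows the keyword `dim`, even though eager mode accepted `axis`.
The original solution was posted here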
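A minimal sketch of the corrected model, assuming the same shapes as in the question. The only change that matters is replacing `axis=2` with `dim=2`; passing the dimension positionally, as in `torch.cat([a, b], 2)`, should also match the `aten::cat` schema.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    """Same dummy model as in the question, with `axis` replaced by `dim`."""

    def forward(self) -> torch.Tensor:
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        # TorchScript resolves this call against the schema
        # aten::cat(Tensor[] tensors, int dim=0), so the keyword must be `dim`.
        return torch.cat([a, b], dim=2)


if __name__ == "__main__":
    model = Model()
    scripted = torch.jit.script(model)  # compiles now that `axis` is gone
    out = scripted()
    print(out.shape)  # torch.Size([6, 1, 24])
```

Concatenating two `[6, 1, 12]` tensors along dimension 2 gives a `[6, 1, 24]` tensor, both in eager mode and in the scripted module.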