Tracing Tensor Sizes in TorchScript

I'm exporting a PyTorch model with TorchScript tracing, but I've run into a problem. Specifically, I need to perform some operations on tensor sizes, but the JIT compiler hard-codes the variable shapes as constants, which breaks compatibility with tensors of different sizes.

For example, take this class:

import torch
from torch import nn


class Foo(nn.Module):
    """Toy class that plays with tensor shape to showcase tracing issue.

    It creates a new tensor with the same shape as the input one, except
    for the last dimension, which is doubled. This new tensor is filled
    based on the values of the input.
    """
    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, x):
        new_shape = (x.shape[0], 2*x.shape[1])  # the offending instruction
        x2 = torch.empty(size=new_shape)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2

and run this test code:

x = torch.randn((3, 5))  # create example input

foo = Foo()
traced_foo = torch.jit.trace(foo, x)  # trace
print(traced_foo(x).shape)  # obviously this works
print(traced_foo(x[:, :4]).shape)  # but fails with a different shape!
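Printing the code recorded by the tracer makes the problem visible: the sizes taken from the example input show up as literal constants.

print(traced_foo.code)  # shapes from the example input are baked in as constants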

I could work around the problem with scripting, but in this case I really need to use tracing. Besides, I would think that tracing should be able to handle tensor size manipulations correctly.

"but in this case I really need to use tracing"

You can freely mix torch.jit.script and torch.jit.trace as needed. For example, one could do this:

import torch


class MySuperModel(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Script the shape-dependent part so tensor sizes stay dynamic
        self.scripted = torch.jit.script(Foo(*args, **kwargs))
        # Bar stands for any module whose forward is safe to trace
        self.traced = Bar(*args, **kwargs)

    def forward(self, data):
        return self.scripted(self.traced(data))

model = MySuperModel()
traced_model = torch.jit.trace(model, example_input)  # a single example tensor, matching forward's one argument
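The snippet above leaves Bar and example_input undefined. Below is a minimal, self-contained sketch that fills the gaps: Bar is a hypothetical stand-in for any shape-agnostic module, Foo is reused from the question, and the last two lines check that the traced model still accepts a different input width.

import torch
from torch import nn


class Bar(nn.Module):
    # Hypothetical placeholder for the shape-agnostic part of the network;
    # it preserves the input shape, so it traces safely.
    def forward(self, x):
        return x * 2.0


example_input = torch.randn(3, 5)
model = MySuperModel()
traced_model = torch.jit.trace(model, example_input)

print(traced_model(example_input).shape)      # torch.Size([3, 10])
print(traced_model(torch.randn(3, 4)).shape)  # torch.Size([3, 8]): the size is no longer baked in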

You can also move the shape-dependent part of the functionality into a separate function and decorate it with @torch.jit.script:
import torch
from torch import nn


@torch.jit.script
def _forward_impl(x):
    new_shape = (x.shape[0], 2*x.shape[1])  # the offending instruction
    x2 = torch.empty(size=new_shape)
    x2[:, ::2] = x
    x2[:, 1::2] = x + 1
    return x2

class Foo(nn.Module):
    def forward(self, x):
        return _forward_impl(x)
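With the shape arithmetic moved into the scripted helper, tracing Foo should now generalize across input sizes. A quick check (the expected shapes simply follow from doubling the last dimension):

foo = Foo()
traced_foo = torch.jit.trace(foo, torch.randn(3, 5))

print(traced_foo(torch.randn(3, 5)).shape)  # torch.Size([3, 10])
print(traced_foo(torch.randn(3, 4)).shape)  # torch.Size([3, 8]), no error this time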

There is no way around script in this case, because it has to understand your code. With tracing, only the operations you perform on the tensors are recorded, and it knows nothing about control flow that depends on the data (or on the data's shape).
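As a small illustration (not part of the original code): a branch that depends on the tensor's values is compiled into a real if by torch.jit.script, whereas tracing would only record whichever side happened to run for the example input.

import torch


@torch.jit.script
def relu_or_negate(x):
    # Data-dependent control flow: script keeps both branches,
    # while a trace would freeze the one taken during tracing.
    if bool(x.sum() > 0):
        return torch.relu(x)
    return -x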

Anyway, this should cover most cases, and if it doesn't, you should be more specific.

This bug has been fixed in 1.10.2.