How do I get the input size of an operator in a PyTorch scripted model?

I use this code to convert my model to a scripted model:

scripted_model = torch.jit.trace(detector.model, images).eval()

Then I print scripted_model. Part of the output is shown below:

 (base): DLA(
    original_name=DLA
    (base_layer): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level0): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level1): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level2): Tree(
      original_name=Tree
      (tree1): BasicBlock(
        original_name=BasicBlock
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
        (conv2): Conv2d(original_name=Conv2d)
        (bn2): BatchNorm2d(original_name=BatchNorm2d)
      )
      (tree2): BasicBlock(
        original_name=BasicBlock
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
        (conv2): Conv2d(original_name=Conv2d)
        (bn2): BatchNorm2d(original_name=BatchNorm2d)
      )
      (root): Root(
        original_name=Root
        (conv): Conv2d(original_name=Conv2d)
        (bn): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
      )
      (downsample): MaxPool2d(original_name=MaxPool2d)
      (project): Sequential(
        original_name=Sequential
        (0): Conv2d(original_name=Conv2d)
        (1): BatchNorm2d(original_name=BatchNorm2d)
      )
    ) 
...

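I just want to get the input size of an operator, for example the size of the input to (0): Conv2d(original_name=Conv2d). I printed the graph of this scripted model, and the output is as follows: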

  %4770 : __torch__.torch.nn.modules.module.___torch_mangle_11.Module = prim::GetAttr[name="wh"](%self.1)
  %4762 : __torch__.torch.nn.modules.module.___torch_mangle_15.Module = prim::GetAttr[name="tracking"](%self.1)
  %4754 : __torch__.torch.nn.modules.module.___torch_mangle_23.Module = prim::GetAttr[name="rot"](%self.1)
  %4746 : __torch__.torch.nn.modules.module.___torch_mangle_7.Module = prim::GetAttr[name="reg"](%self.1)
  %4738 : __torch__.torch.nn.modules.module.___torch_mangle_3.Module = prim::GetAttr[name="hm"](%self.1)
  %4730 : __torch__.torch.nn.modules.module.___torch_mangle_27.Module = prim::GetAttr[name="dim"](%self.1)
  %4722 : __torch__.torch.nn.modules.module.___torch_mangle_19.Module = prim::GetAttr[name="dep"](%self.1)
  %4714 : __torch__.torch.nn.modules.module.___torch_mangle_31.Module = prim::GetAttr[name="amodel_offset"](%self.1)
  %4706 : __torch__.torch.nn.modules.module.___torch_mangle_289.Module = prim::GetAttr[name="ida_up"](%self.1)
  %4645 : __torch__.torch.nn.modules.module.___torch_mangle_262.Module = prim::GetAttr[name="dla_up"](%self.1)
  %4461 : __torch__.torch.nn.modules.module.___torch_mangle_180.Module = prim::GetAttr[name="base"](%self.1)
  %5100 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4461, %input.1)
  %5082 : Tensor, %5083 : Tensor, %5084 : Tensor, %5085 : Tensor, %5086 : Tensor, %5087 : Tensor, %5088 : Tensor, %5089 : Tensor = prim::TupleUnpack(%5100)
  %5101 : (Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4645, %5082, %5083, %5084, %5085, %5086, %5087, %5088, %5089)
  %5097 : Tensor, %5098 : Tensor, %5099 : Tensor = prim::TupleUnpack(%5101)
  %3158 : None = prim::Constant()

I cannot even find the operator names in this graph. How can I get the input size of a specific operator in a scripted model?

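One solution is to try summary from torchinfo. For layers applied one after another, the output shape of one layer is the input shape of the next layer, and so on: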

!pip install torchinfo

from torchinfo import summary
# `model` is your network; input_size is the shape of the input tensor it expects.
summary(model, input_size=(batch_size, 3, 224, 224))

# Output:

    ===============================================================================================
    Layer (type:depth-idx)                        Output Shape              Param #
    ===============================================================================================
    ResNet50                                      --                        --
    ├─ResNet: 1-1                                 [64, 10]                  --
    │    └─Conv2d: 2-1                            [64, 64, 112, 112]        9,408
    │    └─BatchNorm2d: 2-2                       [64, 64, 112, 112]        128
    │    └─ReLU: 2-3                              [64, 64, 112, 112]        --
    │    └─MaxPool2d: 2-4                         [64, 64, 56, 56]          --
    │    └─Sequential: 2-5                        [64, 64, 56, 56]          --
    │    │    └─BasicBlock: 3-1                   [64, 64, 56, 56]          73,984
    │    │    └─BasicBlock: 3-2                   [64, 64, 56, 56]          73,984
    │    └─Sequential: 2-6                        [64, 128, 28, 28]         --
    │    │    └─BasicBlock: 3-3                   [64, 128, 28, 28]         230,144
    │    │    └─BasicBlock: 3-4                   [64, 128, 28, 28]         295,424
    │    └─Sequential: 2-7                        [64, 256, 14, 14]         --
    │    │    └─BasicBlock: 3-5                   [64, 256, 14, 14]         919,040
    │    │    └─BasicBlock: 3-6                   [64, 256, 14, 14]         1,180,672
    │    └─Sequential: 2-8                        [64, 512, 7, 7]           --
    │    │    └─BasicBlock: 3-7                   [64, 512, 7, 7]           3,673,088
    │    │    └─BasicBlock: 3-8                   [64, 512, 7, 7]           4,720,640
    │    └─AdaptiveAvgPool2d: 2-9                 [64, 512, 1, 1]           --
    │    └─Linear: 2-10                           [64, 10]                  5,130
    ===============================================================================================
    Total params: 11,181,642
    Trainable params: 11,181,642
    Non-trainable params: 0
    Total mult-adds (G): 116.07
    ===============================================================================================
    Input size (MB): 38.54
    Forward/backward pass size (MB): 2543.33
    Params size (MB): 44.73
    Estimated Total Size (MB): 2626.59
    ===============================================================================================
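
If you need the input shape of each layer directly, rather than reading it off the previous layer's output, one option is to register forward hooks on the original (untraced) model and run a single forward pass. The following is only a minimal sketch: `record_input_shapes` and `toy_model` are hypothetical names made up for illustration, and the small Sequential network stands in for `detector.model` from the question.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def record_input_shapes(model, example_input):
    """Run one forward pass and return {module_name: [input shapes]} for leaf modules."""
    shapes = OrderedDict()
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # `inputs` is the tuple of positional arguments passed to forward().
            # If a module is called more than once, the last call overwrites earlier ones.
            shapes[name] = [tuple(t.shape) for t in inputs if isinstance(t, torch.Tensor)]
        return hook

    for name, module in model.named_modules():
        # Hook only leaf modules (Conv2d, BatchNorm2d, ...), not containers.
        if len(list(module.children())) == 0:
            handles.append(module.register_forward_hook(make_hook(name)))

    try:
        with torch.no_grad():
            model(example_input)
    finally:
        # Always remove the hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()
    return shapes


# Hypothetical stand-in for detector.model.
toy_model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
).eval()

input_shapes = record_input_shapes(toy_model, torch.randn(1, 3, 64, 64))
for name, shape_list in input_shapes.items():
    print(name, shape_list)
```

With the model from the question, the call would presumably look like record_input_shapes(detector.model, images), made before torch.jit.trace. The keys are the dotted module names from named_modules(), which should line up with the names printed for the scripted model (for example base.base_layer.0).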