Does inplace matter when we return ReLU(x)?

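Is there any difference between the two classes below? I know what inplace is (if inplace is True, you don't need to write x = function(x); calling function(x) is enough to modify x). But here, since we return self.conv(x), it shouldn't matter, right?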

import torch.nn as nn


class ConvBlock(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        down=True,
        use_act=True,
        **kwargs
        ):
        super().__init__()
        self.conv = nn.Sequential((nn.Conv2d(in_channels, out_channels,
                                  padding_mode='reflect',
                                  **kwargs) if down else nn.ConvTranspose2d(in_channels,
                                  out_channels, **kwargs)),
                                  nn.InstanceNorm2d(out_channels),
                                  (nn.ReLU() if use_act else nn.Identity()))

    def forward(self, x):
        return self.conv(x)


class ConvBlockInplace(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        down=True,
        use_act=True,
        **kwargs
        ):
        super().__init__()
        self.conv = nn.Sequential((nn.Conv2d(in_channels, out_channels,
                                  padding_mode='reflect',
                                  **kwargs) if down else nn.ConvTranspose2d(in_channels,
                                  out_channels, **kwargs)),
                                  nn.InstanceNorm2d(out_channels),
                                  (nn.ReLU(inplace=True) if use_act else nn.Identity()))

    def forward(self, x):
        return self.conv(x)

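An in-place operation performs exactly the same amount of computation. What changes is memory: the ReLU overwrites its input tensor instead of allocating a new one for the output, so there are fewer memory accesses. If your task is memory-bound, that is when it "matters".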
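To make this concrete, here is a minimal sketch rather than the code from the question. It assumes a small 1x3x8x8 random input and uses `padding=1`, which the question does not specify. The two `nn.Sequential` stacks are hypothetical stand-ins for `ConvBlock` and `ConvBlockInplace`. Inside the stack, the in-place ReLU overwrites the output of `InstanceNorm2d`, not the caller's `x`, so the returned values should match and `x` should stay untouched. The only difference is whether the ReLU allocates a new tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two blocks: same layers (shared weights),
# differing only in the ReLU's inplace flag. padding=1 is an assumption.
conv = nn.Conv2d(3, 10, kernel_size=3, padding=1, padding_mode="reflect")
norm = nn.InstanceNorm2d(10)
block = nn.Sequential(conv, norm, nn.ReLU())
block_inplace = nn.Sequential(conv, norm, nn.ReLU(inplace=True))

x = torch.randn(1, 3, 8, 8)
x_before = x.clone()

with torch.no_grad():
    out = block(x)
    out_inplace = block_inplace(x)

# Both variants compute the same values, and the caller's x is not modified,
# because the in-place ReLU acts on the intermediate InstanceNorm output.
same_output = torch.allclose(out, out_inplace)
input_untouched = torch.equal(x, x_before)

# In isolation: the in-place ReLU returns a tensor backed by its input's
# memory, while the default ReLU allocates a fresh output tensor.
t = torch.randn(4)
shares_storage = nn.ReLU(inplace=True)(t).data_ptr() == t.data_ptr()
allocates_new = nn.ReLU()(t).data_ptr() != t.data_ptr()

print(same_output, input_untouched, shares_storage, allocates_new)
```

So for the result returned by forward, the flag makes no difference. It only removes one intermediate activation buffer per block.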


I generated the following statistics using the ptflops FLOPs counter:

ConvBlock(
  0.0 M, 100.000% Params, 0.015 GMac, 100.000% MACs, 
  (conv): Sequential(
    0.0 M, 100.000% Params, 0.015 GMac, 100.000% MACs, 
    (0): Conv2d(0.0 M, 100.000% Params, 0.014 GMac, 93.333% MACs, 3, 10, kernel_size=(3, 3), stride=(1, 1), padding_mode=reflect)
    (1): InstanceNorm2d(0.0 M, 0.000% Params, 0.0 GMac, 3.333% MACs, 10, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 3.333% MACs, )
  )
)
Computational complexity:       0.01 GMac
Number of parameters:           280     
Warning: module ConvBlockInplace is treated as a zero-op.
ConvBlockInplace(
  0.0 M, 100.000% Params, 0.015 GMac, 100.000% MACs, 
  (conv): Sequential(
    0.0 M, 100.000% Params, 0.015 GMac, 100.000% MACs, 
    (0): Conv2d(0.0 M, 100.000% Params, 0.014 GMac, 93.333% MACs, 3, 10, kernel_size=(3, 3), stride=(1, 1), padding_mode=reflect)
    (1): InstanceNorm2d(0.0 M, 0.000% Params, 0.0 GMac, 3.333% MACs, 10, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 3.333% MACs, inplace=True)
  )
)
Computational complexity:       0.01 GMac
Number of parameters:           280
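The parameter count in the report can be cross-checked without ptflops. The sketch below assumes the layer configuration visible in the printout above (3 input channels, 10 output channels, a 3x3 kernel, InstanceNorm2d without affine parameters). `make_block` and `count_params` are hypothetical helper names, not part of the original code.

```python
import torch.nn as nn


def count_params(module):
    # Total number of learnable scalars in the module.
    return sum(p.numel() for p in module.parameters())


def make_block(inplace):
    # Mirrors the layers shown in the ptflops printout (assumed configuration).
    return nn.Sequential(
        nn.Conv2d(3, 10, kernel_size=3, padding_mode="reflect"),
        nn.InstanceNorm2d(10),  # affine=False by default, so no parameters
        nn.ReLU(inplace=inplace),
    )


params_plain = count_params(make_block(False))
params_inplace = count_params(make_block(True))

# Conv2d weights: 3 * 10 * 3 * 3 = 270, plus 10 biases = 280 in both cases.
print(params_plain, params_inplace)  # 280 280
```

The inplace flag adds no parameters and no multiply-accumulates, which is why the two reports are identical.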