计算感知损失的 VGG 特征的正确方法

Correct way to compute VGG features for Perceptual loss

在计算VGG Perceptual loss的时候,虽然没见过,但是感觉把GT图像的VGG特征的计算包在torch.no_grad()里面就好了。

所以基本上我觉得下面就可以了,

with torch.no_grad():
    gt_vgg_features = self.vgg_features(gt)

nw_op_vgg_features = self.vgg_features(nw_op)

# Now compute L1 loss

或者应该使用,

gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)

在这两种方法中,requires_grad 设置了 VGG 参数 False 并且 VGG 处于 eval() 模式。

第一种方法将节省大量 GPU 资源,感觉在数值上应该与第二种方法相同,因为不需要通过 GT 图像进行反向传播。但在大多数实现中,我发现第二种方法用于计算 VGG 感知损失。

那么我们应该选择哪个选项来在 PyTorch 中实现 VGG 感知损失?

第一种方式:

with torch.no_grad():
    gt_vgg_features = self.vgg_features(gt)

nw_op_vgg_features = self.vgg_features(nw_op)

尽管 VGG 处于 eval 模式并且其参数保持固定,您仍然需要通过它传播梯度,从特征损失到输出 nw_op。 但是,没有理由计算这些梯度 w.r.t gt.