计算感知损失的 VGG 特征的正确方法
Correct way to compute VGG features for Perceptual loss
在计算VGG Perceptual loss的时候,虽然没见过,但是感觉把GT图像的VGG特征的计算包在torch.no_grad()
里面就好了。
所以基本上我觉得下面就可以了,
with torch.no_grad():
gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)
# Now compute L1 loss
或者应该使用,
gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)
在这两种方法中,requires_grad
设置了 VGG 参数 False
并且 VGG 处于 eval()
模式。
第一种方法将节省大量 GPU 资源,感觉在数值上应该与第二种方法相同,因为不需要通过 GT 图像进行反向传播。但在大多数实现中,我发现第二种方法用于计算 VGG 感知损失。
那么我们应该选择哪个选项来在 PyTorch 中实现 VGG 感知损失?
第一种方式:
with torch.no_grad():
gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)
尽管 VGG 处于 eval
模式并且其参数保持固定,您仍然需要通过它传播梯度,从特征损失到输出 nw_op
。
但是,没有理由计算这些梯度 w.r.t gt
.
在计算VGG Perceptual loss的时候,虽然没见过,但是感觉把GT图像的VGG特征的计算包在torch.no_grad()
里面就好了。
所以基本上我觉得下面就可以了,
with torch.no_grad():
gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)
# Now compute L1 loss
或者应该使用,
gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)
在这两种方法中,requires_grad
设置了 VGG 参数 False
并且 VGG 处于 eval()
模式。
第一种方法将节省大量 GPU 资源,感觉在数值上应该与第二种方法相同,因为不需要通过 GT 图像进行反向传播。但在大多数实现中,我发现第二种方法用于计算 VGG 感知损失。
那么我们应该选择哪个选项来在 PyTorch 中实现 VGG 感知损失?
第一种方式:
with torch.no_grad():
gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)
尽管 VGG 处于 eval
模式并且其参数保持固定,您仍然需要通过它传播梯度,从特征损失到输出 nw_op
。
但是,没有理由计算这些梯度 w.r.t gt
.