在 pytorch 中执行卷积(不是互相关)

Performing Convolution (NOT cross-correlation) in pytorch

我有 a network 我正试图在 pytorch 中实现,但我似乎无法弄清楚如何实现“纯”卷积。在tensorflow中可以这样实现:

def conv2d_flipkernel(x, k, name=None):
    return tf.nn.conv2d(x, flipkernel(k), name=name,
                        strides=(1, 1, 1, 1), padding='SAME')

flipkernel函数为:

def flipkernel(kern):
      return kern[(slice(None, None, -1),) * 2 + (slice(None), slice(None))]

如何在 pytorch 中完成类似的事情?

TLDR 使用函数工具箱中的卷积,torch.nn.fuctional.conv2d, not torch.nn.Conv2d,并围绕垂直轴和水平轴翻转过滤器。


torch.nn.Conv2d 是网络的卷积层。因为权重是学习的,所以它是否使用互相关来实现并不重要,因为网络将简单地学习内核的镜像版本(感谢@etarion 的澄清)。

torch.nn.fuctional.conv2d 使用作为参数提供的输入和权重执行卷积,类似于示例中的 tensorflow 函数。我写了一个简单的测试来确定它是否像 tensorflow 函数一样,实际上是在执行互相关,并且有必要翻转过滤器以获得正确的卷积结果。

import torch
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np

#A vertical edge detection filter. 
#Because this filter is not symmetric, for correct convolution the filter must be flipped before element-wise multiplication
filters = autograd.Variable(torch.FloatTensor([[[[-1, 1]]]]))

#A test image of a square
inputs = autograd.Variable(torch.FloatTensor([[[[0,0,0,0,0,0,0], [0, 0, 1, 1, 1, 0, 0], 
                                             [0, 0, 1, 1, 1, 0, 0], [0, 0, 1, 1, 1, 0, 0],
                                            [0,0,0,0,0,0,0]]]]))
print(F.conv2d(inputs, filters))

这输出

Variable containing:
(0 ,0 ,.,.) = 
  0  0  0  0  0  0
  0  1  0  0 -1  0
  0  1  0  0 -1  0
  0  1  0  0 -1  0
  0  0  0  0  0  0
[torch.FloatTensor of size 1x1x5x6]

此输出是互相关的结果。因此,我们需要翻转过滤器

def flip_tensor(t):
    flipped = t.numpy().copy()

    for i in range(len(filters.size())):
        flipped = np.flip(flipped,i) #Reverse given tensor on dimention i
    return torch.from_numpy(flipped.copy())

print(F.conv2d(inputs, autograd.Variable(flip_tensor(filters.data))))

新的输出是正确的卷积结果。

Variable containing:
(0 ,0 ,.,.) = 
  0  0  0  0  0  0
  0 -1  0  0  1  0
  0 -1  0  0  1  0
  0 -1  0  0  1  0
  0  0  0  0  0  0
[torch.FloatTensor of size 1x1x5x6] 

与上面的答案没什么不同,但是 torch 可以在本地执行 flip(i)(我猜你只想 flip(2)flip(3)):

def convolution(A, B):
  return F.conv2d(A, B.flip(2).flip(3), padding=padding)