了解 MATLAB convn 的行为

Understanding behaviour of MATLAB's convn

我正在做一些张量的卷积。

这是 MATLAB 中的小测试:

    ker= rand(3,4,2);
    a= rand(5,7,2);
    c=convn(a,ker,'valid');
    c11=sum(sum(a(1:3,1:4,1).*ker(:,:,1)))+sum(sum(a(1:3,1:4,2).*ker(:,:,2)));
    c(1,1)-c11  % not equal!

第三行与convn进行N维卷积,我想将convn第一行第一列的结果与手动计算值进行比较。但是,我的计算与 convn 相比并不相等。

那么MATLAB的convn背后是什么?我对张量卷积的理解是错误的吗?

是的,你对卷积的理解是错误的。您的 c11 公式不是卷积:您只是将匹配的索引相乘然后求和。它更像是一种点积运算(在修剪为相同大小的张量上)。我将从一维开始解释。

一维数组

正在输入 conv([4 5 6], [2 3]) returns [8 22 27 18]。我发现从多项式的乘法来看最容易想到这一点:

(4+5x+6x^2)*(2+3x) = 8+22x+27x^2+18x^3

将每个数组的条目用作多项式的系数,将多项式相乘,收集相似的项,并从系数中读取结果。 x 的幂在这里用于跟踪乘法和加法。请注意,x^n 的系数在第 (n+1) 个条目中找到,因为 x 的幂以 0 开头,而索引以 1 开头。

二维数组

输入conv2([2 3; 3 1], [4 5 6; 0 -1 1])returns矩阵

 8  22  27  18
12  17  22   9
 0  -3   2   1

同样,这可以解释为多项式的乘法,但现在我们需要两个变量:比如 x 和 y。 x^n y^m 的系数在 (m+1, n+1) 条目中找到。上面的输出意味着

(2+3x+3y+xy)*(4+5x+6x^2+0y-xy+x^2y) = 8+22x+27x^2+18x^3+12y+17xy+22x ^2y+9x^3y-3xy^2+2x^2y^2+x^3y^2

三维数组

同样的故事。您可以将条目视为变量 x、y、z 中多项式的系数。多项式相乘,乘积的系数就是卷积的结果。

'valid'参数

这只保留了卷积的中心部分:第二个因子的所有项都参与的那些系数。要使其非空,第二个数组的维度不应大于第一个数组。 (这与默认设置不同,默认设置中卷积数组的顺序无关紧要。)示例:

conv([4 5 6], [2 3]) returns [22 27](与上面的一维示例相比)。这对应于

中的事实

(4+5x+6x^2)*(2+3x) = 8+22x+27x^2+18x ^3

粗体字词来自 2 和 3x。

几乎是对的。您的理解有两点稍微不对:

  1. 您选择 valid 作为卷积标志。这意味着从卷积返回的输出有其大小,因此当您使用内核扫过矩阵时,它必须舒适地适合矩阵本身。因此,返回的第一个 "valid" 输出实际上是用于矩阵位置 (2,2,1) 处的计算。这意味着您可以在这个位置舒适地安装您的内核,这对应于输出的位置 (1,1)。为了演示,这就是 aker 对我来说使用上面的代码的样子:

    >> a
    
    a(:,:,1) =
    
    0.9930    0.2325    0.0059    0.2932    0.1270    0.8717    0.3560
    0.2365    0.3006    0.3657    0.6321    0.7772    0.7102    0.9298
    0.3743    0.6344    0.5339    0.0262    0.0459    0.9585    0.1488
    0.2140    0.2812    0.1620    0.8876    0.7110    0.4298    0.9400
    0.1054    0.3623    0.5974    0.0161    0.9710    0.8729    0.8327
    
    
    a(:,:,2) =
    
    0.8461    0.0077    0.5400    0.2982    0.9483    0.9275    0.8572
    0.1239    0.0848    0.5681    0.4186    0.5560    0.1984    0.0266
    0.5965    0.2255    0.2255    0.4531    0.5006    0.0521    0.9201
    0.0164    0.8751    0.5721    0.9324    0.0035    0.4068    0.6809
    0.7212    0.3636    0.6610    0.5875    0.4809    0.3724    0.9042
    
    >> ker
    
    ker(:,:,1) =
    
    0.5395    0.4849    0.0970    0.3418
    0.6263    0.9883    0.4619    0.7989
    0.0055    0.3752    0.9630    0.7988
    
    
    ker(:,:,2) =
    
    0.2082    0.4105    0.6508    0.2669
    0.4434    0.1910    0.8655    0.5021
    0.7156    0.9675    0.0252    0.0674
    

    如您所见,在矩阵 a 中的位置 (2,2,1)ker 可以很容易地放入矩阵中,如果您回忆起卷积,它只是一个总和内核与位置 (2,2,1) 处的矩阵子集之间的逐元素乘积,其大小与您的内核相同(实际上,您需要对内核做一些其他事情,我将在下一点中保留) - 见下文)。因此,您正在计算的系数实际上是 (2,2,1) 处的输出,而不是 (1,1,1) 处的输出。虽然从它的要点来看,你已经知道了这一点,但我想把它放在那里以防你不知道。

  2. 您忘记了对于 N 维卷积,您需要在每个维度上翻转掩码。如果您还记得一维卷积,必须水平翻转掩码。我所说的翻转的意思是您只需将元素按相反的顺序放置即可。例如,[1 2 3 4] 的数组将变为 [4 3 2 1]。在 2D 卷积中,你必须水平和垂直翻转。因此,您将获取矩阵的每一行并以相反的顺序放置每一行,就像一维情况一样。在这里,您会将每一行视为一维信号并进行翻转。完成此操作后,您将采用此翻转结果,并将每个 视为一维信号并再次进行翻转。

    现在,对于 3D,您必须水平、垂直 和时间 翻转。这意味着您需要对矩阵的每个切片独立执行 2D 翻转,然后您将以 3D 方式获取单个列并将其视为 1D 信号。在 MATLAB 语法中,您会得到 ker(1,1,:),将其视为一维信号,然后翻转。您将对 ker(1,2,:)ker(1,3,:) 等重复此操作,直到您完成第一个切片。请记住,我们不会转到第二个切片或任何其他切片并重复我们刚刚做的事情。因为您正在获取矩阵的 3D 部分,所以您本质上是在对您提取的每个 3D 列的所有切片进行操作。因此,只查看矩阵的第一个切片,因此您需要在计算卷积之前对内核执行此操作:

    ker_flipped = flipdim(flipdim(flipdim(ker, 1), 2), 3);
    

    flipdim 在指定轴上执行翻转。在我们的例子中,我们是垂直进行的,然后获取结果并水平进行,然后再次进行临时处理。然后,您将在求和中使用 ker_flipped。请注意,翻转的顺序无关紧要。 flipdim对每个维度独立操作,所以只要你记得翻转所有维度,输出都是一样的。


为了演示,下面是 convn:

的输出结果
c =

    4.1837    4.1843    5.1187    6.1535
    4.5262    5.3253    5.5181    5.8375
    5.1311    4.7648    5.3608    7.1241

现在,要手动确定 c(1,1) 是什么,您需要在 flipped 内核上进行计算:

ker_flipped = flipdim(flipdim(flipdim(ker, 1), 2), 3);
c11 = sum(sum(a(1:3,1:4,1).*ker_flipped(:,:,1)))+sum(sum(a(1:3,1:4,2).*ker_flipped(:,:,2)));

我们得到的输出是:

c11 =

    4.1837

如您所见,这验证了我们使用 convn 在 MATLAB 中通过手动计算得到的结果。如果你想比较更多的精度数字,使用 format long 并比较它们:

>> format long;
>> disp(c11)

   4.183698205668000

>> disp(c(1,1))

   4.183698205668001

如您所见,所有数字都相同,除了最后一位。这归因于数字四舍五入。绝对确定:

>> disp(abs(c11 - c(1,1)));

   8.881784197001252e-16

...我觉得相差10-16就足以让我证明他们是平等的,对吧?