比较 MatLab 的 conv2 和 scipy 的 convolve2d

Comparing MatLab's conv2 with scipy's convolve2d

我正在尝试使用非对称权重计算 S3x3 移动平均线,如本 MatLab example 中所述,我不确定从 MatLab 翻译时我对以下内容的解释是否正确:

  1. 我是否以同样的方式设置了我的矩阵?
  2. 在这种情况下scipy.signal.convolve2d do the same as MatLab's conv2d吗?
  3. 为什么我的身材这么差?!

在 MatLab 中,滤波器的给出和应用如下:

% S3x3 seasonal filter
% Symmetric weights
sW3 = [1/9;2/9;1/3;2/9;1/9];
% Asymmetric weights for end of series
aW3 = [.259 .407;.37 .407;.259 .185;.111 0];

% dat contains data - simplified adaptation from link above
ns = length(dat) ; first = 1:4 ; last = ns - 3:ns; 
trend = conv(dat, sW3, 'same');
trend(1:2) = conv2(dat(first), 1, rot90(aW3,2), 'valid');
trend(ns-1:ns) = conv2(dat(last), 1, aW3, 'valid');

我在 python 中使用我自己的数据对此进行了解释,我在这样做时假设 MatLab 矩阵中的 ; 意味着 新行 并且a space 表示 新列

import numpy as np
from scipy.signal import convolve2d

dat = np.array([0.02360784,  0.0227628 ,  0.0386366 ,  0.03338596,  0.03141621, 0.03430469])
dat = dat.reshape(dat.shape[0], 1) # in columns

sW3 = np.array([[1/9.],[2/9.],[1/3.],[2/9.],[1/9.]])
aW3 = np.array( [[ 0.259,  0.407],
                 [ 0.37 ,  0.407],
                 [ 0.259,  0.185],
                 [ 0.111,  0.   ]])

trend = convolve2d(dat, sW3, 'same')
trend[:2] = convolve2d(dat[:2], np.rot90(aW3,2), 'same')
trend[-2:] = convolve2d(dat[-2:], np.rot90(aW3,2), 'same')

绘制数据,拟合度很差...

import matplotlib.pyplot as plt
plt.plot(dat, 'grey', label='raw data', linewidth=4.)
plt.plot(trend, 'b--', label = 'S3x3 trend')
plt.legend()
plt.plot()

已解决。

事实证明,这个问题确实与 MatLab 的 conv2d 和 scipy 的 convolve2d 中的细微差别有关,来自 docs:

C = conv2(h1,h2,A) first convolves each column of A with the vector h1 and then convolves each row of the result with the vector h2

这意味着在我的代码中边缘实际发生的事情是不正确的。它实际上应该是sum每个端点的卷积和aW3

的相关列

例如

nwtrend = np.zeros(dat.shape)
nwtrend = convolve2d(dat, sW3, 'same')

for i in xrange(np.rot90(aW3, 2).shape[1]):
    nwtrend[i] = np.convolve(dat[i,0], np.rot90(aW3, 2)[:,i], 'same').sum()

for i in xrange(aW3.shape[1]):
    nwtrend[-i-1] = np.convolve(dat[-i-1,0], aW3[:,i], 'same').sum()

得到上面的输出

plt.plot(nwtrend, 'r', label='S3x3 new',linewidth=2.)
plt.plot(trend, 'b--', label='S3x3 old')
plt.legend(loc='lower centre')
plt.show()