计算多变量高斯变化均值的有效方法 - Matlab

Efficient way of computing multivariate gaussian varying the mean - Matlab

是否有一种有效的方法来计算 returns 矩阵 p 的多元高斯分布(如下所示),即利用某种矢量化?我知道矩阵 p 是对称的,但对于大小为 40000x3 的矩阵,例如,这将需要相当长的时间。

Matlab 代码示例:

DataMatrix = [3 1 4; 1 2 3; 1 5 7; 3 4 7; 5 5 1; 2 3 1; 4 4 4];

[rows, cols ] = size(DataMatrix);

I = eye(cols);
p = zeros(rows);

for k = 1:rows

    p(k,:) = mvnpdf(DataMatrix(:,:),DataMatrix(k,:),I);

end

第 1 阶段:破解源代码

我们正在迭代执行 mvnpdf(DataMatrix(:,:),DataMatrix(k,:),I)

语法是:mvnpdf(X,Mu,Sigma).

这样,我们输入的对应关系就变成了:

X = DataMatrix(:,:);
Mu = DataMatrix(k,:);
Sigma = I

对于与我们的情况相关的尺寸,源代码mvnpdf.m减少为-

%// Store size parameters of X
[n,d] = size(X);

%// Get vector mean, and use it to center data
X0 = bsxfun(@minus,X,Mu);

%// Make sure Sigma is a valid covariance matrix
[R,err] = cholcov(Sigma,0);

%// Create array of standardized data, and compute log(sqrt(det(Sigma)))
xRinv = X0 / R;
logSqrtDetSigma = sum(log(diag(R)));

%// Finally get the quadratic form and thus, the final output
quadform = sum(xRinv.^2, 2);
p_out = exp(-0.5*quadform - logSqrtDetSigma - d*log(2*pi)/2)

现在,如果 Sigma 始终是单位矩阵,我们也会将 R 作为单位矩阵。因此,X0 / RX0 相同,保存为 xRinv。所以,本质上 quadform = sum(X0.^2, 2);

因此,原代码-

for k = 1:rows
    p(k,:) = mvnpdf(DataMatrix(:,:),DataMatrix(k,:),I);
end

减少到 -

[n,d] = size(DataMatrix);
[R,err] = cholcov(I,0);
p_out = zeros(rows);
K = sum(log(diag(R))) + d*log(2*pi)/2;
for k = 1:rows  
    X0 = bsxfun(@minus,DataMatrix,DataMatrix(k,:));     
    quadform = sum(X0.^2, 2);
    p_out(k,:) = exp(-0.5*quadform - K);
end

现在,如果输入矩阵的大小为 40000x3,您可能想在此处停止。但是在系统资源允许的情况下,您可以将所有内容矢量化,如下所述。

第 2 阶段:向量化一切

既然我们看到了实际发生的事情并且计算看起来是可并行的,是时候 step-up 使用 bsxfun in 3D with his good friend permute 作为矢量化解决方案了,就像这样 -

%// Get size params and R
[n,d] = size(DataMatrix);
[R,err] = cholcov(I,0);

%// Calculate constants : "logSqrtDetSigma" and  "d*log(2*pi)/2`"
K1 = sum(log(diag(R)));
K2 = d*log(2*pi)/2;

%// Major thing happening here as we calclate "X0" for all iterations
%// in one go with permute and bsxfun
diffs = bsxfun(@minus,DataMatrix,permute(DataMatrix,[3 2 1]));

%// "Sigma" is an identity matrix, so it plays no in "/R" at "xRinv = X0 / R".
%// Perform elementwise squaring and summing rows to get vectorized "quadform"
quadform1 = squeeze(sum(diffs.^2,2))

%// Finally use "quadform1" and get vectorized output as a 2D array
p_out = exp(-0.5*quadform1 - K1 - K2)