计算点集和参考点集之间的马氏距离

Computing Mahalanobis Distance Between Set of Points and Set of Reference Points

我有一个 n x p 矩阵 - mX,它由 R^p 中的 n 个点组成。

我有另一个 m x p 矩阵 - mY 由 R^p 中的 m 个参考点组成。

我想创建一个 n x m 矩阵 - mD,它是 Mahalanobis Distance 矩阵。

D(i, j) 表示 mX, mX(i, :) 中的点 i 和 mY, mY(j, :) 中的点 j 之间的Mahalanobis Distance

即计算如下:

mD(i, j) = (mX(i, :) - mY(j, :)) * inv(mC) * (mX(i, :) - mY(j, :)).';

其中 mC 是给定的马氏距离 PSD 矩阵。

很容易在循环中完成,有没有向量化的方法?

即输入为 mX、mY 和 mC 输出为 mD 且不使用任何 MATLAB 工具箱进行完全矢量化的函数?

谢谢。

我得出结论,向量化这个问题效率不高。我对这个问题进行矢量化的最佳想法是需要 m x n x p x p 工作内存,至少在一次处理所有内容的情况下。这意味着在 n=m=p=152 的情况下,代码已经需要 4GB Ram。在这些维度上,我的系统可以 运行 不到一秒的循环:

mD=zeros(size(mX,1),size(mY,1));
ImC=inv(mC);
for i=1:size(mX,1)
    for j=1:size(mY,1)
        d=mX(i, :) - mY(j, :);
        mD(i, j) = (d) * ImC * (d).';
    end
end

这是一种消除一个循环的解决方案

function d = mahalanobis(mX, mY)

    n = size(mX, 2);
    m = size(mY, 2);
    data = [mX, mY];
    mc = cov(transpose(data));

    dist = zeros(n,m);
    for i = 1 : n
        diff = repmat(mX(:,i), 1, m) - mY;
        dist(i,:) = sum((mc\diff).*diff , 1);
    end
    d = sqrt(dist);

end

您可以将其调用为:

d = mahalanobis(transpose(X),transpose(Y))

方法 #1

假设 无限 资源,这是一个使用 bsxfunmatrix-multiplication -

的矢量化解决方案
A = reshape(bsxfun(@minus,permute(mX,[1 3 2]),permute(mY,[3 1 2])),[],p);
out = reshape(diag(A*inv(mC)*A.'),n,m);

方法 #2

这是一个尝试降低循环复杂度的组合解决方案 -

A = reshape(bsxfun(@minus,permute(mX,[1 3 2]),permute(mY,[3 1 2])),[],p);
imC = inv(mC);
out = zeros(n*m,1);
for ii = 1:n*m
    out(ii) = A(ii,:)*imC*A(ii,:).';
end
out = reshape(out,n,m);

样本运行-

>> n = 3;  m = 4;   p = 5;
mX = rand(n,p);
mY = rand(m,p);
mC = rand(p,p);
imC = inv(mC);
>> %// Original solution
for i = 1:n
    for j = 1:m
        mD(i, j) = (mX(i, :) - mY(j, :)) * inv(mC) * (mX(i, :) - mY(j, :)).'; %//'
    end
end
>> mD
mD =
      -8.4256       10.032       2.8929       7.1762
      -44.748      -4.3851      -13.645      -9.6702
      -4.5297       3.2928      0.11132       2.5998
>> %// Approach #1
A = reshape(bsxfun(@minus,permute(mX,[1 3 2]),permute(mY,[3 1 2])),[],p);
out = reshape(diag(A*inv(mC)*A.'),n,m);  %//'
>> out
out =
      -8.4256       10.032       2.8929       7.1762
      -44.748      -4.3851      -13.645      -9.6702
      -4.5297       3.2928      0.11132       2.5998
>> %// Approach #2
A = reshape(bsxfun(@minus,permute(mX,[1 3 2]),permute(mY,[3 1 2])),[],p);
imC = inv(mC);
out1 = zeros(n*m,1);
for ii = 1:n*m
    out1(ii) = A(ii,:)*imC*A(ii,:).';  %//'
end
out1 = reshape(out1,n,m);
>> out1
out1 =
      -8.4256       10.032       2.8929       7.1762
      -44.748      -4.3851      -13.645      -9.6702
      -4.5297       3.2928      0.11132       2.5998

相反,如果你有:

mD(j, i) = (mX(i, :) - mY(j, :)) * inv(mC) * (mX(i, :) - mY(j, :)).';

解决方案将转换为下一个列出的版本。

方法 #1

A = reshape(bsxfun(@minus,permute(mY,[1 3 2]),permute(mX,[3 1 2])),[],p);
out = reshape(diag(A*inv(mC)*A.'),m,n);

方法 #2

A = reshape(bsxfun(@minus,permute(mY,[1 3 2]),permute(mX,[3 1 2])),[],p);
imC = inv(mC);
out1 = zeros(m*n,1);
for i = 1:n*m
    out(i) = A(i,:)*imC*A(i,:).';  %//'
end
out = reshape(out,m,n);

样本运行-

>> n = 3; m = 4; p = 5;
mX = rand(n,p);    mY = rand(m,p);     mC = rand(p,p);  imC = inv(mC);
>> %// Original solution
for i = 1:n
    for j = 1:m
        mD(j, i) = (mX(i, :) - mY(j, :)) * inv(mC) * (mX(i, :) - mY(j, :)).'; %//'
    end
end
>> mD
mD =
      0.81755      0.33205      0.82254
       1.7086       1.3363       2.4209
      0.36495      0.78394     -0.33097
      0.17359       0.3889      -1.0624
>> %// Approach #1
A = reshape(bsxfun(@minus,permute(mY,[1 3 2]),permute(mX,[3 1 2])),[],p);
out = reshape(diag(A*inv(mC)*A.'),m,n);  %//'
>> out
out =
      0.81755      0.33205      0.82254
       1.7086       1.3363       2.4209
      0.36495      0.78394     -0.33097
      0.17359       0.3889      -1.0624
>> %// Approach #2
A = reshape(bsxfun(@minus,permute(mY,[1 3 2]),permute(mX,[3 1 2])),[],p);
imC = inv(mC);
out1 = zeros(m*n,1);
for i = 1:n*m
    out1(i) = A(i,:)*imC*A(i,:).';  %//'
end
out1 = reshape(out1,m,n);
>> out1
out1 =
      0.81755      0.33205      0.82254
       1.7086       1.3363       2.4209
      0.36495      0.78394     -0.33097
      0.17359       0.3889      -1.0624

减少到 L2

如果允许预处理矩阵mC并且不怕数值差异,马氏距离似乎可以减少到普通的L2距离。

首先,计算 mC:

的 Cholesky 分解
mR = chol(mC)      % C = R^t * R, where R is upper-triangular

现在我们可以使用这些因素重新制定马氏距离:

(Xi-Yj) * inv(C) * (Xi-Yj)^t = || (Xi-Yj) inv(R) ||^2 = ||TXi - TYj||^2
where:  TXi = Xi * inv(R)
        TYj = Yj * inv(R)

所以思路是先将点XiYj变换为TXiTYj,然后计算它们之间的欧氏距离。这是算法大纲:

  1. 计算 mR - 协方差矩阵 mC 的 Cholesky 因子(需要 O(p^3) 时间)。
  2. 逆三角矩阵mR(需要O(p^3)时间)。
  3. mXmY 乘以右边的 inv(mR)(需要 O(p^2 (m+n))时间)。
  4. 计算点对之间的平方 L2 距离(需要 O(m n p) 时间)。

总时间为 O(m n p + (m + n) p^2 + p^3) 与原始 O(m n p^2)。它应该在 1 << p << n,m 时工作得更快。在这种情况下,第 4 步将花费大部分时间,应该进行矢量化。

矢量化

我对 MATLAB 经验不多,但对 x86 CPU 上的 SIMD 向量化有相当多的了解。在原始计算中,沿着一个足够大的数组维度进行矢量化就足够了,并对其他维度进行简单的循环。

如果您希望 p 足够大,可以沿点的坐标向量化,并为 i <= nj <= m 制作两个嵌套循环。这类似于@Daniel 发布的内容。

如果 p 不够大,您可以改为沿其中一个点序列矢量化。这类似于@dpmcmlxxvi 发布的解决方案:您必须从第二个矩阵的所有行中减去一个矩阵的单行,然后计算结果行的平方范数。重复 n 次( 或 m 次)。

对于我来说,完全矢量化(这意味着在 MATLAB 中用矩阵运算而不是循环重写)听起来不像是一个聪明的性能目标。最有可能的部分矢量化解决方案速度最快。