MATLAB 中 all(a(:,i)==a,1) 的更快替代方案

A faster alternative to all(a(:,i)==a,1) in MATLAB

这是一个简单的问题:在 MATLAB 中是否有比 all(a(:,i)==a,1) 更快的替代方法?

我正在考虑一种在整个过程中受益于短路评估的实现。我的意思是,all() 肯定会受益于短路评估,但 a(:,i)==a 不会。

我尝试了下面的代码,

% example for the input matrix

m = 3;       % m and n aren't necessarily equal to those values.
n = 5000;    % It's only possible to know in advance that 'm' << 'n'.

a = randi([0,5],m,n); % the maximum value of 'a' isn't necessarily equal to 
                      % 5 but it's possible to state that every element in 
                      % 'a' is a positive integer.

% all, equal solution

tic
for i = 1:n % stepping up the elapsed time in orders of magnitude
    %%%%%%%%%% all and equal solution %%%%%%%%%
    ax_boo = all(a(:,i)==a,1);
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
end
toc

% alternative solution

tic
for i = 1:n % stepping up the elapsed time in orders of magnitude
    %%%%%%%%%%% alternative solution %%%%%%%%%%%
    ax_boo = a(1,i) == a(1,:);
    for k = 2:m
        ax_boo(ax_boo) = a(k,i) == a(k,ax_boo);
    end
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
end
toc

但直观的是,MATLAB 环境中的任何“for-loop-solution”自然会变慢。我想知道是否有用更快的语言编写的 MATLAB 内置函数。

编辑:

经过 运行 次更多测试后,我发现隐式扩展确实会对评估 a(:,i)==a 的性能产生影响。如果矩阵 a 有多于一行,all(repmat(a(:,i),[1,n])==a,1) 可能比 all(a(:,i)==a,1) 更快,具体取决于列数 (n)。对于 n=5000,事实证明 repmat 显式扩展更快。

但我认为,如果 a 的所有元素都是正整数,Kenneth Boyd 的回答的概括就是“最终解决方案”。我不会以原始形式处理 a (m x n 矩阵),而是存储和处理 adec (1 x n 矩阵):

exps = ((0):(m-1)).';
base = max(a,[],[1,2]) + 1;
adec = sum( a .* base.^exps , 1 );

换句话说,每一列将被编码为一个整数。当然 adec(i)==adecall(a(:,i)==a,1) 快。

编辑 2:

我忘了提到 adec 方法有功能限制。充其量,将 adec 存储为 uint64,以下不等式必须成立 base^m < 2^64 + 1.

作为替代方案,您可以使用 unique 的第三个输出:

[~, ~, iu] = unique(a.', 'rows');

for i = 1:n
  ax_boo = iu(i) == iu;
end

如评论中所示:

ax_boo isolates the indices of the columns I have to sum in a row vector b. So, basically the next line would be something like c = sum(b(ax_boo),2);

这是accumarray的典型用法:

[~, ~, iu] = unique(a.', 'rows');
C = accumarray(iu,b);
for i = 1:n
  c = C(i);
end

由于您的目标是计算匹配的列数,我的示例将二进制编码转换为整数小数,然后您只需遍历可能的值(3 行是 8 个可能的值)并计算数字匹配项。

a_dec = 2.^(0:(m-1)) * a;
num_poss_values = 2 ^ m;
num_matches = zeros(num_poss_values, 1);
for i = 1:num_poss_values
   num_matches(i) = sum(a_dec == (i - 1));
end

在我的电脑上,使用 2020a,以下是前 2 个选项和上面代码的执行时间:

Elapsed time is 0.246623 seconds.
Elapsed time is 0.553173 seconds.
Elapsed time is 0.000289 seconds.

所以我的代码快了 853 倍!

我编写了我的代码,因此它可以使用 m 作为任意整数。

num_matches 变量包含转换为小数后加起来为 0、1、2、...7 的列数。