计算数组中元素的最快方法是什么?

What is the fastest way to count elements in an array?

在我的模型中,重复次数最多的任务之一是计算数组中每个元素的数量。计数来自一个封闭的集合,所以我知道有 X 种类型的元素,并且全部或部分元素填充数组,以及代表 'empty' 单元格的零。该数组没有以任何方式排序,并且可能很长(大约 1M 个元素),并且该任务在一次模拟中完成了数千次(这也是数百次模拟的一部分)。结果应该是大小为 X 的向量 r,因此 r(k) 是数组中 k 的数量。

示例:

对于X = 9,如果我有以下输入向量:

v = [0 7 8 3 0 4 4 5 3 4 4 8 3 0 6 8 5 5 0 3]

我想得到这个结果:

r = [0 0 4 4 3 1 1 3 0]

注意我不想要零的个数,数组中没有出现的元素(比如2)在结果的对应位置有一个0向量 (r(2) == 0).

实现此目标的最快方法是什么?

tl;dr: 最快的方法取决于数组的大小。对于小于 214 的数组,下面的方法 3 (accumarray) 更快。对于大于下面方法 2 的数组 (histcounts) 更好。

更新:我也使用 2016b 引入的 implicit broadcasting 进行了测试,结果几乎与 bsxfun 方法相同,此方法没有显着差异(相对于其他方法)。


让我们看看执行此任务的可用方法有哪些。对于以下示例,我们假设 Xn 个元素,从 1 到 n,我们感兴趣的数组是 M,它是一个列数组,可以在尺寸。我们的结果向量将是 spp1,这样 spp(k) 就是 Mk 的数量。虽然我这里写的是X,但在下面的代码中并没有明确的实现,我只是定义了n = 500,而X隐含的是1:500.

天真的 for 循环

处理此任务的最简单直接的方法是通过 for 循环迭代 X 中的元素并计算 M 中等于它的元素数:

function spp = loop(M,n)
spp = zeros(n,1);
for k = 1:size(spp,1);
    spp(k) = sum(M==k); 
end
end

这当然不是那么聪明,特别是如果只有来自 X 的一小部分元素正在填充 M,所以我们最好先寻找那些已经在 M 中的元素:

function spp = uloop(M,n)
u = unique(M); % finds which elements to count
spp = zeros(n,1);
for k = u(u>0).';
    spp(k) = sum(M==k); 
end
end

通常,在 MATLAB 中,建议尽可能多地利用内置函数,因为大多数时候它们要快得多。我想到了 5 个选项:

1.函数 tabulate

函数 tabulate returns 一个非常方便的频率 table 乍一看似乎是这个任务的完美解决方案:

function tab = tabi(M)
tab = tabulate(M);
if tab(1)==0
    tab(1,:) = [];
end
end

唯一要做的修复是删除 table 的第一行,如果它计算 0 元素(可能是 M 中没有零) .

2。函数 histcounts

另一个可以很容易地根据我们的需要进行调整的选项 histcounts:

function spp = histci(M,n)
spp = histcounts(M,1:n+1);
end

这里,为了分别统计1到n之间的所有不同元素,我们将边定义为1:n+1,所以X中的每个元素都有自己的bin。我们也可以写 histcounts(M(M>0),'BinMethod','integers'),但我已经测试过了,它需要更多时间(尽管它使函数独立于 n)。

3.函数 accumarray

我将在这里带来的下一个选项是使用函数 accumarray:

function spp = accumi(M)
spp = accumarray(M(M>0),1);
end

这里我们给函数 M(M>0) 作为输入,跳过零,并使用 1 作为 vals 输入来计算所有唯一元素。

4.函数 bsxfun

我们甚至可以使用二元运算 @eq(即 ==)来查找每种类型的所有元素:

function spp = bsxi(M,n)
spp = bsxfun(@eq,M,1:n);
spp = sum(spp,1);
end

如果我们将第一个输入 M 和第二个 1:n 保持在不同的维度,所以一个是列向量另一个是行向量,那么函数比较 M1:n 中的每个元素,并创建一个 length(M)-by-n 逻辑矩阵,然后我们可以求和以获得所需的结果。

5.函数 ndgrid

另一个选项,类似于 bsxfun,是使用 ndgrid 函数显式创建所有可能性的两个矩阵:

function spp = gridi(M,n)
[Mx,nx] = ndgrid(M,1:n);
spp = sum(Mx==nx);
end

然后我们比较它们并对列求和,以获得最终结果。

基准测试

我做了一些测试,从上面提到的所有方法中找到最快的方法,我为所有路径定义了 n = 500。对于某些人(尤其是天真的 for),n 对执行时间有很大影响,但这不是这里的问题,因为我们想针对给定的 n 对其进行测试.

结果如下:

我们可以注意到几件事:

  1. 有趣的是,最快的方法发生了变化。对于小于 214 的数组,accumarray 是最快的。对于大于 214 的数组,histcounts 是最快的。
  2. 正如预期的那样,朴素的 for 循环在两个版本中都是最慢的,但是对于小于 28 的数组,"unique & for" 选项更慢。 ndgrid 在大于 211 的数组中成为最慢的,可能是因为需要在内存中存储非常大的矩阵。
  3. tabulate 处理小于 29 的数组的方式有些不规则。在我进行的所有试验中,这个结果是一致的(模式有一些变化)。

bsxfunndgrid 曲线是 t运行 的,因为它让我的电脑卡在更高的值,趋势已经很清楚了)

此外,请注意 y 轴在 log10 中,因此单位减少(例如大小为 219[=180= 的数组) ], 在 accumarrayhistcounts) 之间意味着操作速度提高了 10 倍。

我很高兴在评论中听到对此测试的改进,如果您有其他概念上不同的方法,非常欢迎您提出建议作为答案。

代码

以下是计时函数中包含的所有函数:

function out = timing_hist(N,n)
M = randi([0 n],N,1);
func_times = {'for','unique & for','tabulate','histcounts','accumarray','bsxfun','ndgrid';
    timeit(@() loop(M,n)),...
    timeit(@() uloop(M,n)),...
    timeit(@() tabi(M)),...
    timeit(@() histci(M,n)),...
    timeit(@() accumi(M)),...
    timeit(@() bsxi(M,n)),...
    timeit(@() gridi(M,n))};
out = cell2mat(func_times(2,:));
end

function spp = loop(M,n)
spp = zeros(n,1);
for k = 1:size(spp,1);
    spp(k) = sum(M==k); 
end
end

function spp = uloop(M,n)
u = unique(M);
spp = zeros(n,1);
for k = u(u>0).';
    spp(k) = sum(M==k); 
end
end

function tab = tabi(M)
tab = tabulate(M);
if tab(1)==0
    tab(1,:) = [];
end
end

function spp = histci(M,n)
spp = histcounts(M,1:n+1);
end

function spp = accumi(M)
spp = accumarray(M(M>0),1);
end

function spp = bsxi(M,n)
spp = bsxfun(@eq,M,1:n);
spp = sum(spp,1);
end

function spp = gridi(M,n)
[Mx,nx] = ndgrid(M,1:n);
spp = sum(Mx==nx);
end

这里是 运行 此代码并生成图表的脚本:

N = 25; % it is not recommended to run this with N>19 for the `bsxfun` and `ndgrid` functions.
func_times = zeros(N,5);
for n = 1:N
    func_times(n,:) = timing_hist(2^n,500);
end
% plotting:
hold on
mark = 'xo*^dsp';
for k = 1:size(func_times,2)
    plot(1:size(func_times,1),log10(func_times(:,k).*1000),['-' mark(k)],...
        'MarkerEdgeColor','k','LineWidth',1.5);
end
hold off
xlabel('Log_2(Array size)','FontSize',16)
ylabel('Log_{10}(Execution time) (ms)','FontSize',16)
legend({'for','unique & for','tabulate','histcounts','accumarray','bsxfun','ndgrid'},...
    'Location','NorthWest','FontSize',14)
grid on

1 取这个奇怪名字的原因来自于我的领域,生态学。我的模型是元胞自动机,通常在虚拟 space(上面的 M)中模拟个体生物。这些个体属于不同的物种(因此 spp),它们共同构成了所谓的 "ecological community"。群落的 "state" 由每个物种的个体数量给出,即此答案中的 spp 向量。在这个模型中,我们首先为要从中提取的个体定义一个物种库(X 以上),社区状态考虑到物种库中的所有物种,而不仅仅是 M 中存在的物种

我们知道输入向量总是包含整数,那么为什么不使用它来 "squeeze" 提高算法的性能呢?

我一直在尝试对两种最佳分箱方法进行一些优化 ,这就是我想出的:

  • 唯一值的数量(问题中的X,或示例中的n)应显式转换为(无符号)整数类型。
  • 计算一个额外的 bin 然后丢弃它比 "only process" 有效值更快(参见下面的 accumi_new 函数)。

此函数在我的机器上 运行 大约需要 30 秒。我正在使用 MATLAB R2016a。


function q38941694
datestr(now)
N = 25;
func_times = zeros(N,4);
for n = 1:N
    func_times(n,:) = timing_hist(2^n,500);
end
% Plotting:
figure('Position',[572 362 758 608]);
hP = plot(1:n,log10(func_times.*1000),'-o','MarkerEdgeColor','k','LineWidth',2);
xlabel('Log_2(Array size)'); ylabel('Log_{10}(Execution time) (ms)')
legend({'histcounts (double)','histcounts (uint)','accumarray (old)',...
  'accumarray (new)'},'FontSize',12,'Location','NorthWest')
grid on; grid minor;
set(hP([2,4]),'Marker','s'); set(gca,'Fontsize',16);
datestr(now)
end

function out = timing_hist(N,n)
% Convert n into an appropriate integer class:
if n < intmax('uint8')
  classname = 'uint8';
  n = uint8(n);
elseif n < intmax('uint16')
  classname = 'uint16';
  n = uint16(n);
elseif n < intmax('uint32')
  classname = 'uint32';
  n = uint32(n);
else % n < intmax('uint64')  
  classname = 'uint64';
  n = uint64(n);
end
% Generate an input:
M = randi([0 n],N,1,classname);
% Time different options:
warning off 'MATLAB:timeit:HighOverhead'
func_times = {'histcounts (double)','histcounts (uint)','accumarray (old)',...
  'accumarray (new)';
    timeit(@() histci(double(M),double(n))),...
    timeit(@() histci(M,n)),...
    timeit(@() accumi(M)),...
    timeit(@() accumi_new(M))
    };
out = cell2mat(func_times(2,:));
end

function spp = histci(M,n)
  spp = histcounts(M,1:n+1);
end

function spp = accumi(M)
  spp = accumarray(M(M>0),1);
end

function spp = accumi_new(M)
  spp = accumarray(M+1,1);
  spp = spp(2:end);
end