Matlab 中沿维度子集的多维数组的 Argmax

Argmax of a multidimensional array along a subset of dimensions in Matlab

比如说,Y 是一个 7 维数组,我需要一种有效的方法来沿着最后 3 个维度最大化它,这将在 GPU 上运行。 因此,我需要一个具有 Y 最大值的 4 维数组和三个在最后三个维度中具有这些值索引的 4 维数组。 我可以

[Y7, X7] = max(Y , [], 7);
[Y6, X6] = max(Y7, [], 6);
[Y5, X5] = max(Y6, [], 5);

那么我已经找到了值 (Y5) 和第 5 维 (X5) 的索引。但我仍然需要沿第 6 和第 7 个维度的索引。

这是一种方法。让 N 表示要最大化的维数。

  1. 重塑 Y 以将最后 N 个维度折叠成一个维度。
  2. 沿着折叠的维度最大化。这给出了 argmax 作为这些维度的线性索引。
  3. 将线性索引展开为 N 个子索引,每个维度一个。

以下代码适用于任意数量的维度(不一定像您的示例中的73)。为实现这一点,它一般处理 Y 的大小,并使用从元胞数组获得的逗号分隔列表从 sub2ind.

获得 N 输出
Y = rand(2,3,2,3,2,3,2); % example 7-dimensional array
N = 3; % last dimensions along which to maximize
D = ndims(Y);
sz = size(Y);
[~, ind] = max(reshape(Y, [sz(1:D-N) prod(sz(D-N+1:end))]), [], D-N+1);
sub = cell(1,N);
[sub{:}] = ind2sub(sz(D-N+1:D), ind);

作为检查,在运行上面的代码之后,观察例如Y(2,3,1,2,:)(为方便起见显示为行向量):

>> reshape(Y(2,3,1,2,:), 1, [])
ans =
    0.5621    0.4352    0.3672    0.9011    0.0332    0.5044    0.3416    0.6996    0.0610    0.2638    0.5586    0.3766

最大值显示为 0.9011,它出现在第 4 位置(其中 "position" 是沿 N=3 折叠维度定义的)。事实上,

>> ind(2,3,1,2)
ans =
     4
>> Y(2,3,1,2,ind(2,3,1,2))
ans =
    0.9011

或者,就 N=3 个子指数而言,

>> Y(2,3,1,2,sub{1}(2,3,1,2),sub{2}(2,3,1,2),sub{3}(2,3,1,2))
ans =
    0.9011