如何在 4d numpy 数组中沿 2d 获取最大值的索引
How to get the index of maximum values along 2d in a 4d numpy array
假设您有一个 3d 数组,例如
A = np.array([[[1,2], [2,3] ], [[3,4], [1,2]], [[2,3], [3,4]]])
A.shape (3, 2, 2)
array([[[1, 2],
[2, 3]],
[[3, 4],
[1, 2]],
[[2, 3],
[3, 4]]])
现在,如果我想获得沿第一个维度的最大值的索引,这很容易,因为
A.argmax(axis=0)
array([[1, 1],
[2, 2]])
如何对 4d 数组执行相同的操作,在前两个维度上找到最大值?
示例 4d 数组
np.random.randint(0, 9,[2,3,2,2])
B = array([[[[5, 4],
[3, 8]],
[[3, 5],
[0, 4]],
[[0, 1],
[3, 0]]],
[[[0, 2],
[7, 3]],
[[7, 3],
[8, 0]],
[[8, 3],
[2, 7]]]])
B.shape (2, 3, 2, 2)
在那种情况下,输出应该仍然是一个(2, 2)
矩阵,但是每个单元格包含一个维度0和维度1的最大索引的元组,即
示例输出
array([[(1, 2), (0, 1)],
[(1, 1), (0, 0)]])
解决方案
正如@hpaulj 正确指出的那样,这个问题的解决方案在于调用
idx = B.reshape(6,2,2).argmax(axis=0)
2d_indices = np.unravel_index(idx,(2,3))
但是,这种排序方式与我预期的有点不同(并在我上面的问题中概述)。要获得与上面完全相同的输出,只需添加
list_of_tuples = [a for a in zip(2d_indices[0].flatten(), 2d_indices[1].flatten())]
np.array(list_of_tuples).reshape(3, 3, -1)
我觉得有点麻烦,可能有更直接的方法,但为了完整起见,我还是想 post :)
argmax
只接受标量轴值(其他一些函数允许元组)。但是我们可以重塑 B
:
In [18]: B.reshape(6,2,2).argmax(axis=0)
Out[18]:
array([[5, 1],
[4, 0]])
并将它们转换回二维索引:
In [21]: np.unravel_index(_18,(2,3))
Out[21]:
(array([[1, 0],
[1, 0]]),
array([[2, 1],
[1, 0]]))
这些值可以通过多种方式重新排序,例如:
In [23]: np.transpose(_21)
Out[23]:
array([[[1, 2],
[1, 1]],
[[0, 1],
[0, 0]]])
假设您有一个 3d 数组,例如
A = np.array([[[1,2], [2,3] ], [[3,4], [1,2]], [[2,3], [3,4]]])
A.shape (3, 2, 2)
array([[[1, 2],
[2, 3]],
[[3, 4],
[1, 2]],
[[2, 3],
[3, 4]]])
现在,如果我想获得沿第一个维度的最大值的索引,这很容易,因为
A.argmax(axis=0)
array([[1, 1],
[2, 2]])
如何对 4d 数组执行相同的操作,在前两个维度上找到最大值?
示例 4d 数组
np.random.randint(0, 9,[2,3,2,2])
B = array([[[[5, 4],
[3, 8]],
[[3, 5],
[0, 4]],
[[0, 1],
[3, 0]]],
[[[0, 2],
[7, 3]],
[[7, 3],
[8, 0]],
[[8, 3],
[2, 7]]]])
B.shape (2, 3, 2, 2)
在那种情况下,输出应该仍然是一个(2, 2)
矩阵,但是每个单元格包含一个维度0和维度1的最大索引的元组,即
示例输出
array([[(1, 2), (0, 1)],
[(1, 1), (0, 0)]])
解决方案
正如@hpaulj 正确指出的那样,这个问题的解决方案在于调用
idx = B.reshape(6,2,2).argmax(axis=0)
2d_indices = np.unravel_index(idx,(2,3))
但是,这种排序方式与我预期的有点不同(并在我上面的问题中概述)。要获得与上面完全相同的输出,只需添加
list_of_tuples = [a for a in zip(2d_indices[0].flatten(), 2d_indices[1].flatten())]
np.array(list_of_tuples).reshape(3, 3, -1)
我觉得有点麻烦,可能有更直接的方法,但为了完整起见,我还是想 post :)
argmax
只接受标量轴值(其他一些函数允许元组)。但是我们可以重塑 B
:
In [18]: B.reshape(6,2,2).argmax(axis=0)
Out[18]:
array([[5, 1],
[4, 0]])
并将它们转换回二维索引:
In [21]: np.unravel_index(_18,(2,3))
Out[21]:
(array([[1, 0],
[1, 0]]),
array([[2, 1],
[1, 0]]))
这些值可以通过多种方式重新排序,例如:
In [23]: np.transpose(_21)
Out[23]:
array([[[1, 2],
[1, 1]],
[[0, 1],
[0, 0]]])