argmax 索引相对于行

argmax index relatively to row

我有以下 numpy 数组:

a = array([[0.25077832, 0.42227767, 0.43744429],
           [0.28539526, 0.44163316, 0.40298769],
           [0.35807141, 0.2856717 , 0.33536935],
           [0.55462028, 0.53807624, 0.38644028],
           [0.18301549, 0.26485082, 0.1366992 ],
           [0.26986122, 0.41347949, 0.43940707],
           [0.25237023, 0.25293316, 0.32400949],
           [0.73319735, 0.70901891, 0.42607705],
           [0.43295485, 0.57804534, 0.48060421],
           [0.32985147, 0.40481195, 0.27833032],
           [0.70750743, 0.65494225, 0.31475339],
           [0.31292153, 0.28715622, 0.37811341],
           [1.        , 0.88283385, 0.42551691],
           [0.59371996, 0.58705522, 0.43142012],
           [0.73690667, 0.63809236, 0.34826675],
           [0.36728384, 0.48038203, 0.52019776],
           [0.67126132, 0.56541162, 0.39211614],
           [0.13064738, 0.17167055, 0.11997984],
           [0.41040586, 0.38921193, 0.37591999],
           [0.25267195, 0.30835439, 0.34994184],
           [0.62562474, 0.57021446, 0.34875994],
           [0.88283382, 1.        , 0.44805126],
           [0.45318085, 0.58569116, 0.51384947],
           [0.23267012, 0.4015728 , 0.41292012],
           [0.5935659 , 0.62583221, 0.36990236],
           [0.62760122, 0.55182008, 0.44280949],
           [0.55419133, 0.48549239, 0.23645478],
           [0.66452635, 0.6102609 , 0.4181133 ]])

当我尝试 a.argmax() 时,我得到 36,这是正确的,因为 12 行由 36 元素组成,值为 1.。但我需要索引 0 ,这意味着 12

中的 0 元素

预期输出:(12, 0)

IIUC,可以使用np.divmod:

np.divmod(a.argmax(), a.shape[1])

np.nditer:

list(np.ndindex(a.shape))[a.argmax()]

或者,按照@Ali_Sh的建议: np.unravel_index 适用于任意数量的维度!

np.unravel_index(a.argmax(), a.shape)

输出:(12, 0)