argmax 索引相对于行
argmax index relatively to row
我有以下 numpy 数组:
a = array([[0.25077832, 0.42227767, 0.43744429],
[0.28539526, 0.44163316, 0.40298769],
[0.35807141, 0.2856717 , 0.33536935],
[0.55462028, 0.53807624, 0.38644028],
[0.18301549, 0.26485082, 0.1366992 ],
[0.26986122, 0.41347949, 0.43940707],
[0.25237023, 0.25293316, 0.32400949],
[0.73319735, 0.70901891, 0.42607705],
[0.43295485, 0.57804534, 0.48060421],
[0.32985147, 0.40481195, 0.27833032],
[0.70750743, 0.65494225, 0.31475339],
[0.31292153, 0.28715622, 0.37811341],
[1. , 0.88283385, 0.42551691],
[0.59371996, 0.58705522, 0.43142012],
[0.73690667, 0.63809236, 0.34826675],
[0.36728384, 0.48038203, 0.52019776],
[0.67126132, 0.56541162, 0.39211614],
[0.13064738, 0.17167055, 0.11997984],
[0.41040586, 0.38921193, 0.37591999],
[0.25267195, 0.30835439, 0.34994184],
[0.62562474, 0.57021446, 0.34875994],
[0.88283382, 1. , 0.44805126],
[0.45318085, 0.58569116, 0.51384947],
[0.23267012, 0.4015728 , 0.41292012],
[0.5935659 , 0.62583221, 0.36990236],
[0.62760122, 0.55182008, 0.44280949],
[0.55419133, 0.48549239, 0.23645478],
[0.66452635, 0.6102609 , 0.4181133 ]])
当我尝试 a.argmax()
时,我得到 36,这是正确的,因为 12
行由 36
元素组成,值为 1.
。但我需要索引 0
,这意味着 12
行
中的 0
元素
预期输出:(12, 0)
IIUC,可以使用np.divmod
:
np.divmod(a.argmax(), a.shape[1])
list(np.ndindex(a.shape))[a.argmax()]
或者,按照@Ali_Sh的建议:
np.unravel_index
适用于任意数量的维度!
np.unravel_index(a.argmax(), a.shape)
输出:(12, 0)
我有以下 numpy 数组:
a = array([[0.25077832, 0.42227767, 0.43744429],
[0.28539526, 0.44163316, 0.40298769],
[0.35807141, 0.2856717 , 0.33536935],
[0.55462028, 0.53807624, 0.38644028],
[0.18301549, 0.26485082, 0.1366992 ],
[0.26986122, 0.41347949, 0.43940707],
[0.25237023, 0.25293316, 0.32400949],
[0.73319735, 0.70901891, 0.42607705],
[0.43295485, 0.57804534, 0.48060421],
[0.32985147, 0.40481195, 0.27833032],
[0.70750743, 0.65494225, 0.31475339],
[0.31292153, 0.28715622, 0.37811341],
[1. , 0.88283385, 0.42551691],
[0.59371996, 0.58705522, 0.43142012],
[0.73690667, 0.63809236, 0.34826675],
[0.36728384, 0.48038203, 0.52019776],
[0.67126132, 0.56541162, 0.39211614],
[0.13064738, 0.17167055, 0.11997984],
[0.41040586, 0.38921193, 0.37591999],
[0.25267195, 0.30835439, 0.34994184],
[0.62562474, 0.57021446, 0.34875994],
[0.88283382, 1. , 0.44805126],
[0.45318085, 0.58569116, 0.51384947],
[0.23267012, 0.4015728 , 0.41292012],
[0.5935659 , 0.62583221, 0.36990236],
[0.62760122, 0.55182008, 0.44280949],
[0.55419133, 0.48549239, 0.23645478],
[0.66452635, 0.6102609 , 0.4181133 ]])
当我尝试 a.argmax()
时,我得到 36,这是正确的,因为 12
行由 36
元素组成,值为 1.
。但我需要索引 0
,这意味着 12
行
0
元素
预期输出:(12, 0)
IIUC,可以使用np.divmod
:
np.divmod(a.argmax(), a.shape[1])
list(np.ndindex(a.shape))[a.argmax()]
或者,按照@Ali_Sh的建议:
np.unravel_index
适用于任意数量的维度!
np.unravel_index(a.argmax(), a.shape)
输出:(12, 0)