使用 argmax 时如何检测 numpy 数组中的平局

How to detect a tie in a numpy array when using argmax

如果我有一个如下所示的数组,我如何在使用 np.argmax() 时检测到至少有 3 个或更多值的平局?

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

np.argmax(examp, axis=1)

给出输出:

array([0, 0, 3, 1, 1]

以第一行为例,有“三平”。 4 的 3 个值。 np.argmax returns 具有最大值的第一个索引。但是,我如何才能检测到存在“三向平局”并让它使用自定义函数决定决胜局(前提是至少存在“三向平局”?

因此,第一行:看到有 4 的“3 路平局”。自定义函数运行,以便它可以决定决胜局。

第二行:“4 路平局”同样的事情发生了。

第三行:只有“2路平局”,低于至少“3路平局”的条件。可以默认为np.argmax.

你是正确的 np.argmax 只会找到 第一个 最大值。尽管您可以计算出其中有多少 argmax 存在,并根据您的逻辑得出该数字

indices = examp.argmax(0)
counts = (examp == examp[indices, np.r_[:3]]).sum(0)
# the same as
counts = np.count_nonzero(examp == examp[indices, np.r_[:3]], axis=0)

会return

indices = array([0, 3, 2])
counts = array([4, 1, 2])

找到第 n 个最大值的一种方法是 np.partition (or np.argpartition)。在这种情况下,您可以这样做:

>>> n = 3  # Size of tie
>>> i = examp.argpartition([-n, -1], axis=-1)

倒数第三列和最后一列中的值保证按正确的排序顺序排列(因此倒数第二列也是如此,但仅在这种有限的情况下)。如果这两个值彼此相等,则您有一个三向平局:

>>> r = np.arange(examp.shape[0])
>>> examp[r, i[:, -n]] == examp[r, i[:, -1]]
array([ True,  True, False, False, False])

您还可以使用 np.diff 来计算掩码:

>>> np.diff(examp[r[:, None], i[:, [-n, -1]]], axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])

您可以使用 np.take_along_axis 而不是第一个索引 r:

来获得类似的结果
>>> np.diff(np.take_along_axis(examp, i[:, -n::n-1], 1), axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])

在所有这些情况下,argmax 的值只是 i[:, -1],因为这是数组中最大值的索引。

由于您已经在使用 numpy,我强烈建议您也矢量化自定义打破平局函数。我在这里提供了作为掩码的输出,这样您就可以尽可能高效地做到这一点。