如何将 np.argmax 合并到广播中以替换给定的 for-loop 代码(如果可能)

How to incorporate np.argmax into broadcasting to replace the given for-loop code (if possible)

我一直在尝试使用 Numpy 来尽可能地加速代码。它真的非常快。但是,有时确实需要一些巧妙的思考。我想熟能生巧。

不是我解释我的问题,而是我要替换的内容:

# We use U and A to compute V

U = np.array([[1,2],
             [4,3],
             [5,6],
             [7,8]])

V = np.zeros(U.shape)

A = np.array([[1,3],
              [3,4]])

# The for loop to be replaced

for t in range(len(U)):
    V[t] = np.argmax( U[t]*A.T ,axis = 1)

我的尝试:

V = np.argmax(U[:,np.newaxis]*A.T,axis=1)

# U[:,np.newaxis]*A.T 

别害怕,我知道我的 Numpy 版本有什么问题。注释掉的代码确实给出了正确的中间值,但是,我没有像在我的 for 循环代码中那样正确地合并 np.argmax 部分。我认为也许不能。如果可能的话,请帮助我。非常感谢。

你走在正确的轨道上,只是方向错了。你也不需要转置 A:

(A * U[:,np.newaxis]).argmax(-1) # or equivalently axis 2

array([[1, 1],
       [1, 0],
       [1, 1],
       [1, 1]], dtype=int64)