花式索引一个 numpy 矩阵:每行一个元素

fancy indexing a numpy matrix: one element per row

我有一个 2d numpy 数组,矩阵,形状为 (m, n)。我的实际用例有 m ~ 1e5 和 n ~ 100,但为了有一个简单的最小示例:

matrix = np.arange(5*3).reshape((5, 3))

我有一个整数索引数组 idx,形状为 (m, ),每个条目都在 [0, n) 之间。该数组指定应从 矩阵 的每一行中 select 编辑哪一列。

idx = np.array([2, 0, 2, 1, 1])

因此,我正在尝试 select 第 0 行的第 2 列、第 1 行的第 0 列、第 2 行的第 2 列、第 1 行的第 1 列和第 4 行的第 1 列。因此,最终答案应该是:

correct_result = np.array((2, 3, 8, 10, 13))

我试过以下方法,很直观,但不正确:

incorrect_result = matrix[:, idx]

上述语法的作用是将 idx 逐行应用为奇特的索引数组,从而产生另一个形状为 (m, n) 的矩阵,这不是我想要的。

这种类型的奇特索引的正确语法是什么?

correct_result = matrix[np.arange(m), idx]

高级索引表达式 matrix[I, J] 给出了 output[n] == matrix[I[n], J[n]].

的输出

如果我们想要output[n] == matrix[n, idx[n]],那么我们需要I[n] == nJ[n] == idx[n],所以我们需要Inp.arange(m)J成为 idx.