如何使用每行的索引矩阵索引矩阵的行元素?
How to index row elements of a Matrix with a Matrix of indices for each row?
我有一个索引矩阵,例如
I = np.array([[1, 0, 2], [2, 1, 0]])
第 i 行的索引 select 是第 i 行中另一个矩阵 M 的一个元素。
所以有 M 例如
M = np.array([[6, 7, 8], [9, 10, 11])
M[I] 应该 select:
[[7, 6, 8], [11, 10, 9]]
我可以:
I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])
I2 = np.ravel(I)
Result = M[I1, I2].reshape(I.shape)
但这看起来很复杂,我正在寻找更优雅的解决方案。最好不要压扁和整形。
示例中我使用了numpy,但实际上我使用的是jax。所以如果jax有更高效的解决方案,欢迎分享。
这一行代码怎么样?这个想法是枚举矩阵的行和行索引,所以你可以访问索引矩阵中的相应行。
import numpy as np
I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])
Result = np.array([row[I[i]] for i, row in enumerate(M)])
print(Result)
输出:
[[ 7 6 8]
[11 10 9]]
In [108]: I = np.array([[1, 0, 2], [2, 1, 0]])
...: M = np.array([[6, 7, 8], [9, 10, 11]])
...:
...: I,M
我必须在 M 中添加一个 ']'。
Out[108]:
(array([[1, 0, 2],
[2, 1, 0]]),
array([[ 6, 7, 8],
[ 9, 10, 11]]))
高级索引 broadcasting
:
In [110]: M[np.arange(2)[:,None],I]
Out[110]:
array([[ 7, 6, 8],
[11, 10, 9]])
第一个索引具有形状 (2,1),它与 I
到 select 的 (2,3) 形状的 (2,3) 值块配对。
np.take_along_axis
也可以在这里使用 M
的值,使用索引 I
over axis=1
:
>>> np.take_along_axis(M, I, axis=1)
array([[ 7, 6, 8],
[11, 10, 9]])
我有一个索引矩阵,例如
I = np.array([[1, 0, 2], [2, 1, 0]])
第 i 行的索引 select 是第 i 行中另一个矩阵 M 的一个元素。
所以有 M 例如
M = np.array([[6, 7, 8], [9, 10, 11])
M[I] 应该 select:
[[7, 6, 8], [11, 10, 9]]
我可以:
I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])
I2 = np.ravel(I)
Result = M[I1, I2].reshape(I.shape)
但这看起来很复杂,我正在寻找更优雅的解决方案。最好不要压扁和整形。
示例中我使用了numpy,但实际上我使用的是jax。所以如果jax有更高效的解决方案,欢迎分享。
这一行代码怎么样?这个想法是枚举矩阵的行和行索引,所以你可以访问索引矩阵中的相应行。
import numpy as np
I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])
Result = np.array([row[I[i]] for i, row in enumerate(M)])
print(Result)
输出:
[[ 7 6 8]
[11 10 9]]
In [108]: I = np.array([[1, 0, 2], [2, 1, 0]])
...: M = np.array([[6, 7, 8], [9, 10, 11]])
...:
...: I,M
我必须在 M 中添加一个 ']'。
Out[108]:
(array([[1, 0, 2],
[2, 1, 0]]),
array([[ 6, 7, 8],
[ 9, 10, 11]]))
高级索引 broadcasting
:
In [110]: M[np.arange(2)[:,None],I]
Out[110]:
array([[ 7, 6, 8],
[11, 10, 9]])
第一个索引具有形状 (2,1),它与 I
到 select 的 (2,3) 形状的 (2,3) 值块配对。
np.take_along_axis
也可以在这里使用 M
的值,使用索引 I
over axis=1
:
>>> np.take_along_axis(M, I, axis=1)
array([[ 7, 6, 8],
[11, 10, 9]])