从堆叠矩阵中获取批量索引 - Python Jax

Get batched indices from stacked matrices - Python Jax

我想提取堆叠矩阵的索引。

假设我们有一个维度为 (3, 2, 4) 的数组 a,这意味着我们有三个维度为 (2,4) 的数组和一个索引列表 (3, 2) .

def get_cols(x,idx):  
  x = x[:,idx]
  return x


idx = jnp.array([[0,1],[2,3],[1,2]])

a = jnp.array([[[1,2,3,4],
            [3,2,2,4]],
           
           [[100,20,3,50],
            [5,5,2,4]],
                         
           [[1,2,3,4],
            [3,2,2,4]]
           ])



e = jax.vmap(get_cols, in_axes=(None,0))(a,idx)

我想提取给定一批索引的不同矩阵的列。我希望得到以下结果:

e = [[[[1,2],
  [3,2]],

  [[100,20],
  [5,5]],

  [[1,2],
  [3,2]]],
 
 
 
 [[[3,4],
  [2,4]],
  
  [[3,50],
  [2,4]],
  
  [[3,4],
  [2,4]]],
 
 
 
 
[[[2,3],
[2,2]],
           
[[20,3],
 [5,2]],
                         
[[2,3],
[2,2]]]]

我错过了什么?

您似乎对输入的双 vmap 感兴趣;例如像这样:

e = jax.vmap(jax.vmap(get_cols, in_axes=(0, None)), in_axes=(None, 0))(a, idx)
print(e)
[[[[  1   2]
   [  3   2]]

  [[100  20]
   [  5   5]]

  [[  1   2]
   [  3   2]]]


 [[[  3   4]
   [  2   4]]

  [[  3  50]
   [  2   4]]

  [[  3   4]
   [  2   4]]]


 [[[  2   3]
   [  2   2]]

  [[ 20   3]
   [  5   2]]

  [[  2   3]
   [  2   2]]]]