Numpy 高级索引使用

Numpy advanced indexing usage

案例 1(已解决): 数组 A 的形状(比如)(300,50)。数组 B 是一个形状为 (300,5) 的索引数组,这样 B[i,j] 表示行 i 的另一行的索引到 "concate" 行旁边 i.最终结果是一个形状为 (300,5,50) 的数组 C,因此 C[i,j,:] = A[B[i,j],:]。这可以通过调用 A[B,:].

来完成

这是案例 1 的小脚本示例:

import numpy as np

## A is the data array
A = np.arange(20).reshape((5,4))
## B indicate for each row which rows to pull together
B = np.array([[0,2],[1,2],[2,0],[3,4],[4,1]])
A[B,:] #The desired result

情况2(未解决):同样的问题,只是现在A的形状是(100,300,50)。如果 B 是形状为 (100,300,5) 的索引矩阵,则最终结果将是形状为 (100,300,5,50) 的数组 C,使得 C[i,j,k,:] = A[i,B[i,j,k],:]A[B,:] 不再起作用,因为它由于广播而产生形状 (100,300,5,300,50)。

我应该如何使用索引来解决这个问题?

一种方法是重塑 2D 保持列数不变,然后使用扁平的 B 索引索引到第一个轴,最后重塑回所需的轴。

因此,实施将是 -

A.reshape(-1,A.shape[-1])[B.ravel()].reshape(100,300,5,50)

那些仅仅对数组进行视图的重塑,应该是非常有效的。

这两种情况都解决了。这是案例 #1 -

的示例 运行

1) 输入:

In [667]: A = np.random.rand(3,4)
     ...: B = np.random.randint(0,3,(3,5))
     ...: 

2) 原始方法:

In [668]: A[B,:]
Out[668]: 
array([[[ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.1 ,  0.91,  0.1 ,  0.98]],

       [[ 0.45,  0.16,  0.02,  0.02],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.45,  0.16,  0.02,  0.02]],

       [[ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.45,  0.16,  0.02,  0.02]]])

3) 建议方法:

In [669]: A.reshape(-1,A.shape[-1])[B.ravel()].reshape(3,5,4)
Out[669]: 
array([[[ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.1 ,  0.91,  0.1 ,  0.98]],

       [[ 0.45,  0.16,  0.02,  0.02],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.45,  0.16,  0.02,  0.02]],

       [[ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.45,  0.16,  0.02,  0.02]]])