How to quickly select a sub matrix in a 2-dimensional matrix using numpy?

I have a 7×7 matrix and want to slice out a sub-matrix quickly, without using a loop.

matrix = array([[ 0,  1,  2,  3,  4,  5,  6],
                [ 7,  8,  9, 10, 11, 12, 13],
                [14, 15, 16, 17, 18, 19, 20],
                [21, 22, 23, 24, 25, 26, 27],
                [28, 29, 30, 31, 32, 33, 34],
                [35, 36, 37, 38, 39, 40, 41],
                [42, 43, 44, 45, 46, 47, 48]])

sub_matrix = array([[1,2,3], [16,17,18], [28,29,30], [39,40,41]])

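In practice the matrix is very large. I have a list of row indices and a list of the starting column for each of those rows. It is hard to specify the full list of column slices for all rows directly.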

I tried the approach below, but it raises an error: IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (4,) (4,3)

slice_row = [0, 2, 4, 5]
slice_col_start = [1,2,0,4]
interval = 3
slice_col = [np.arange(x,x+interval) for x in slice_col_start]

matrix[slice_row, np.r_[slice_col]]

If you have the index arrays, you can do:

x = np.array([[1,2,3], [2,3,4], [0,1,2], [4,5,6]])
y = np.array([0, 2, 4, 5])
matrix[y[:,None], x]

Output:

array([[ 1,  2,  3],
       [16, 17, 18],
       [28, 29, 30],
       [39, 40, 41]])
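
The difference from the failing attempt in the question comes down to broadcasting: index arrays of shape (4,) and (4, 3) cannot be broadcast together, while (4, 1) and (4, 3) can. A minimal sketch, reusing the arrays above (the names failed, rows and sub are only illustrative):

```python
import numpy as np

matrix = np.arange(49).reshape(7, 7)
x = np.array([[1, 2, 3], [2, 3, 4], [0, 1, 2], [4, 5, 6]])  # column indices, shape (4, 3)
y = np.array([0, 2, 4, 5])                                  # row indices, shape (4,)

# (4,) against (4, 3): the trailing dimensions 4 and 3 differ, so NumPy raises IndexError
try:
    matrix[y, x]
    failed = False
except IndexError:
    failed = True

# (4, 1) against (4, 3): the size-1 axis is stretched, pairing each row with its 3 columns
rows = y[:, None]
sub = matrix[rows, x]
print(failed, rows.shape, x.shape, sub.shape)
```

The result takes the broadcast shape of the two index arrays, which is why sub comes out as 4 rows by 3 columns.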

This can be done with np.take_along_axis, provided the cols array is given:

rows = np.array([0, 2, 4, 5], dtype=np.int32)
cols = np.array([[1,2,3], [2,3,4], [0,1,2], [4,5,6]])
result = np.take_along_axis(matrix[rows], cols, axis=1)
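
For reference, a self-contained version of the same idea, assuming the array being indexed is the 7×7 matrix from the question (the intermediate name picked_rows is only illustrative):

```python
import numpy as np

matrix = np.arange(49).reshape(7, 7)
rows = np.array([0, 2, 4, 5])
cols = np.array([[1, 2, 3], [2, 3, 4], [0, 1, 2], [4, 5, 6]])

# matrix[rows] first picks the 4 wanted rows, giving shape (4, 7);
# take_along_axis then picks, within each of those rows, the listed columns.
picked_rows = matrix[rows]
result = np.take_along_axis(picked_rows, cols, axis=1)
print(result)
```

Note that matrix[rows] copies the full selected rows before the columns are taken, so for a very wide matrix this may use more memory than indexing rows and columns in a single step.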

Thanks to Kevin, I worked out a solution:

import numpy as np
matrix = np.arange(7*7).reshape(7,7)
slice_row = np.array([0, 2, 4, 5])
slice_col_start = np.array([1,2,0,4])
interval = 3
slice_col = [np.arange(x,x+interval).tolist() for x in slice_col_start]

sub_matrix = matrix[slice_row[:, None], slice_col]
print(sub_matrix)

Output:

[[ 1  2  3]
 [16 17 18]
 [28 29 30]
 [39 40 41]]

In [11]: arr = np.arange(49).reshape(7,7)
In [12]: slice_row = [0, 2, 4, 5]
    ...: slice_col_start = [1,2,0,4]
    ...: interval = 3
In [13]: idx1 = np.array(slice_row)
In [14]: idx2 = np.array(slice_col_start)

Since the interval is fixed, we can use linspace to create all the column indices with one call:

In [19]: idx3 = np.linspace(idx2,idx2+interval, interval, endpoint=False,dtype=int)
In [20]: idx3
Out[20]: 
array([[1, 2, 0, 4],
       [2, 3, 1, 5],
       [3, 4, 2, 6]])

Then it is just a matter of indexing:

In [21]: arr[idx1[:,None], idx3.T]
Out[21]: 
array([[ 1,  2,  3],
       [16, 17, 18],
       [28, 29, 30],
       [39, 40, 41]])

Or use broadcasted addition:

In [23]: idx2[:,None] + np.arange(3)
Out[23]: 
array([[1, 2, 3],
       [2, 3, 4],
       [0, 1, 2],
       [4, 5, 6]])
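
Putting the broadcasted addition together with the row indexing, the whole selection could be wrapped up roughly as follows. This is a sketch only: select_windows is a hypothetical helper name, not a NumPy function, and it does no bounds checking.

```python
import numpy as np

def select_windows(arr, row_idx, col_start, width):
    """Stack arr[row_idx[i], col_start[i]:col_start[i] + width] for every i."""
    row_idx = np.asarray(row_idx)
    col_start = np.asarray(col_start)
    cols = col_start[:, None] + np.arange(width)  # column indices, shape (n, width)
    return arr[row_idx[:, None], cols]            # selected values, shape (n, width)

arr = np.arange(49).reshape(7, 7)
sub = select_windows(arr, [0, 2, 4, 5], [1, 2, 0, 4], 3)
print(sub)
```

If a start plus the width runs past the last column, the indexing step raises an IndexError rather than truncating the way a plain slice would.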

If the interval varies from row to row, we will have to use some form of iteration to build the full list of column indices.
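
As a rough illustration of the varying-interval case, here is one way it might look. The widths list is made up for the example and does not come from the question. Because the rows then have different lengths, the result cannot be a rectangular array.

```python
import numpy as np

arr = np.arange(49).reshape(7, 7)
row_idx = [0, 2, 4, 5]
col_start = [1, 2, 0, 4]
widths = [3, 2, 4, 1]  # hypothetical per-row widths, not from the question

# Plain iteration: the result is ragged, so it stays a list of 1-D arrays
pieces = [arr[r, c:c + w] for r, c, w in zip(row_idx, col_start, widths)]

# Alternative: iterate only to build the column indices, then index once
flat_rows = np.repeat(row_idx, widths)
flat_cols = np.concatenate([np.arange(c, c + w) for c, w in zip(col_start, widths)])
flat = arr[flat_rows, flat_cols]  # all selected values as one 1-D array, row by row
print(pieces)
print(flat)
```

The second form still loops over the rows to build the column indices, but the lookup into the large array happens in a single indexing operation.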