在 python 中从没有循环的数组中提取数组
Extract array from array without loop in python
我正在尝试从数组中提取数组的一部分。
假设我有一个数组 array1
,形状为 (M, N, P)
。对于我的具体情况,M = 10
、N = 5
、P = 2000
。我有另一个数组 array2
,形状为 (M, N, 1)
,它包含 array1
中沿最后一个轴的有趣数据的起点。我想从 array2
给出的索引开始提取 50 个数据点,有点像这样:
array1[:, :, array2:array2 + 50]
我希望得到形状 (M, N, 50)
的结果。不幸的是我得到了错误:
TypeError: only integer scalar arrays can be converted to a scalar index
当然我也可以通过遍历数组得到结果,但我觉得一定有更聪明的方法,因为我经常需要这个。
由于您在每个位置的索引未对齐,您可以创建一个掩码或花哨的索引来提取所需的元素。由于提取的值将是平坦的,因此您必须重新调整它们的形状。
创建遮罩的方法如下:
K = 50
mask = np.zeros((M, N, P + 1), dtype=np.int8)
np.put_along_axis(mask, array2, 1, axis=-1)
np.put_along_axis(mask, array2 + K, -1, axis=-1)
mask.cumsum(axis=-1, out=mask)
mask = mask[..., :-1].view(bool)
我们使用 np.int8
和 np.bool_
具有相同的项目大小,并且 np.cumsum
将初始掩码位置传播到每个轴的最终位置。
剩下的很简单:
array3 = array1[mask].reshape(M, N, K)
您可以绕过 np.put_along_axis
并在适当的地方使用直接索引和裁剪来避免构造掩码时的额外元素:
mask = np.zeros_like(array1, dtype=np.int8)
r = np.tile(np.arange(M)[:, None, None], (1, N, 1))
c = np.tile(np.arange(N)[None, :, None], (M, 1, 1))
clip_mask = array2 + K < P
mask[r, c, array2] = 1
mask[r[clip_mask], c[clip_mask], array2[clip_mask] + K] = -1
mask = np.cumsum(mask, axis=-1, out=mask).view(bool)
这一切都非常浪费:要获得一个形状为 (M, N, K)
的数组,您正在创建一个大小为 (M, N, P)
的布尔掩码以及一些大小为 (M, N, 1)
的索引数组,另一个大小为 (M, N, 1)
的掩码,然后是那些索引数组的一些掩码版本。在这里使用 for
循环真的没有错,只要你编译它们,例如用 cython 或 numba.
您可以通过比较 array2 中的值与最后一个维度的索引范围来构建掩码:
例如:
import numpy as np
M,N,P,k = 4,2,15,3 # yours would be 10,5,2000,50
A1 = np.arange(M*N*P).reshape((M,N,P))
A2 = np.arange(M*N).reshape((M,N,1)) + 1
rP = np.arange(P)[None,None,:]
A3 = A1[(rP>=A2)&(rP<A2+k)].reshape((M,N,k))
输入:
print(A1)
[[[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14]
[ 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29]]
[[ 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44]
[ 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59]]
[[ 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74]
[ 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89]]
[[ 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104]
[105 106 107 108 109 110 111 112 113 114 115 116 117 118 119]]]
print(A2)
[[[1]
[2]]
[[3]
[4]]
[[5]
[6]]
[[7]
[8]]]
输出:
print(A3)
[[[ 1 2 3]
[ 17 18 19]]
[[ 33 34 35]
[ 49 50 51]]
[[ 65 66 67]
[ 81 82 83]]
[[ 97 98 99]
[113 114 115]]]
我正在尝试从数组中提取数组的一部分。
假设我有一个数组 array1
,形状为 (M, N, P)
。对于我的具体情况,M = 10
、N = 5
、P = 2000
。我有另一个数组 array2
,形状为 (M, N, 1)
,它包含 array1
中沿最后一个轴的有趣数据的起点。我想从 array2
给出的索引开始提取 50 个数据点,有点像这样:
array1[:, :, array2:array2 + 50]
我希望得到形状 (M, N, 50)
的结果。不幸的是我得到了错误:
TypeError: only integer scalar arrays can be converted to a scalar index
当然我也可以通过遍历数组得到结果,但我觉得一定有更聪明的方法,因为我经常需要这个。
由于您在每个位置的索引未对齐,您可以创建一个掩码或花哨的索引来提取所需的元素。由于提取的值将是平坦的,因此您必须重新调整它们的形状。
创建遮罩的方法如下:
K = 50
mask = np.zeros((M, N, P + 1), dtype=np.int8)
np.put_along_axis(mask, array2, 1, axis=-1)
np.put_along_axis(mask, array2 + K, -1, axis=-1)
mask.cumsum(axis=-1, out=mask)
mask = mask[..., :-1].view(bool)
我们使用 np.int8
和 np.bool_
具有相同的项目大小,并且 np.cumsum
将初始掩码位置传播到每个轴的最终位置。
剩下的很简单:
array3 = array1[mask].reshape(M, N, K)
您可以绕过 np.put_along_axis
并在适当的地方使用直接索引和裁剪来避免构造掩码时的额外元素:
mask = np.zeros_like(array1, dtype=np.int8)
r = np.tile(np.arange(M)[:, None, None], (1, N, 1))
c = np.tile(np.arange(N)[None, :, None], (M, 1, 1))
clip_mask = array2 + K < P
mask[r, c, array2] = 1
mask[r[clip_mask], c[clip_mask], array2[clip_mask] + K] = -1
mask = np.cumsum(mask, axis=-1, out=mask).view(bool)
这一切都非常浪费:要获得一个形状为 (M, N, K)
的数组,您正在创建一个大小为 (M, N, P)
的布尔掩码以及一些大小为 (M, N, 1)
的索引数组,另一个大小为 (M, N, 1)
的掩码,然后是那些索引数组的一些掩码版本。在这里使用 for
循环真的没有错,只要你编译它们,例如用 cython 或 numba.
您可以通过比较 array2 中的值与最后一个维度的索引范围来构建掩码:
例如:
import numpy as np
M,N,P,k = 4,2,15,3 # yours would be 10,5,2000,50
A1 = np.arange(M*N*P).reshape((M,N,P))
A2 = np.arange(M*N).reshape((M,N,1)) + 1
rP = np.arange(P)[None,None,:]
A3 = A1[(rP>=A2)&(rP<A2+k)].reshape((M,N,k))
输入:
print(A1)
[[[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14]
[ 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29]]
[[ 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44]
[ 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59]]
[[ 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74]
[ 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89]]
[[ 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104]
[105 106 107 108 109 110 111 112 113 114 115 116 117 118 119]]]
print(A2)
[[[1]
[2]]
[[3]
[4]]
[[5]
[6]]
[[7]
[8]]]
输出:
print(A3)
[[[ 1 2 3]
[ 17 18 19]]
[[ 33 34 35]
[ 49 50 51]]
[[ 65 66 67]
[ 81 82 83]]
[[ 97 98 99]
[113 114 115]]]