如何根据一列的值将 numpy 数组拆分为子数组

how to split a numpy array into subarrays based on values of one colums

我有一个很大的 numpy 数组,想拆分它。我读过 但它帮不了我。目标列可以有多个值,但我知道我想根据哪个值来拆分它。在我的简化示例中,目标列是第三列,我想根据值 2. 拆分它。这是我的数组。

import numpy as np
big_array = np.array([[0., 10., 2.],
                      [2., 6., 2.],
                      [3., 1., 7.1],
                      [3.3, 6., 7.8],
                      [4., 5., 2.],
                      [6., 6., 2.],
                      [7., 1., 2.],
                      [8., 5., 2.1]])

具有此值的行 (2.) 进行一次拆分。然后,不是 2. 的下一行(第三和第四行)再做一个。在我的数据集中,我再次看到这个值 (2.) 并从中拆分出来,我再次将非 2. 值(最后一行)保留为另一个拆分。最终结果应如下所示:

spl_array = [np.array([[0., 10., 2.],
                       [2., 6., 2.]]),
             np.array([[3., 1., 7.1],
                      [3.3, 6., 7.8]]),
             np.array([[4., 5., 2.],
                      [6., 6., 2.],
                      [7., 1., 2.]]),
             np.array([[8., 5., 2.1]])]

提前感谢任何帮助。

首先你找到所有包含 2 或不包含 2 的数组。这个数组将充满 True 和 False 值。将此数组转换为包含 0 和 1 的数组。检查哪里有差异(比如 [0, 0, 1, 1, 0] 将是:0, 1, 0, -1.

基于这一变化,可以使用 numpy where 找到这些值的索引。

为大数组插入索引 0 和最后一个索引,这样您就可以将它们压缩到左右切片中。

import numpy as np
big_array = np.array([[0., 10., 2.],
                      [2., 6., 2.],
                      [3., 1., 7.1],
                      [3.3, 6., 7.8],
                      [4., 5., 2.],
                      [6., 6., 2.],
                      [7., 1., 2.],
                      [8., 5., 2.1]])
idx = [2 in array for array in big_array]
idx *= np.ones(len(idx))
slices = list(np.where(np.diff(idx) != 0)[0] + 1)
slices.insert(0,0)
slices.append(len(big_array))

result = list()
for left, right in zip(slices[:-1], slices[1:]):
    result.append(big_array[left:right])

'''
[array([[ 0., 10.,  2.],
        [ 2.,  6.,  2.]]),
 array([[3. , 1. , 7.1],
        [3.3, 6. , 7.8]]),
 array([[4., 5., 2.],
        [6., 6., 2.],
        [7., 1., 2.]]),
 array([[8. , 5. , 2.1]])]
'''

您可以使用 numpy

np.split(
    big_array,
    np.flatnonzero(np.diff(big_array[:,2] == 2) != 0) + 1
)

输出

[array([[ 0., 10.,  2.],
        [ 2.,  6.,  2.]]),
 array([[3. , 1. , 7.1],
        [3.3, 6. , 7.8]]),
 array([[4., 5., 2.],
        [6., 6., 2.],
        [7., 1., 2.]]),
 array([[8. , 5. , 2.1]])]