Numpy:Reshape/horizontally 将 3D 数组拆分为 4D 数组

Numpy: Reshape/horizontally split 3D array into 4D array

我有这样的 3D np.array:

arr3d = np.arange(36).reshape(3, 2, 6)

array([[[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]],

       [[12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]],

       [[24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]]])

我需要将 arr3d 的每个窗格水平拆分为 3 个块,例如:

np.array(np.hsplit(arr3d[0, :, :], 3))

array([[[ 0,  1],
        [ 6,  7]],

       [[ 2,  3],
        [ 8,  9]],

       [[ 4,  5],
        [10, 11]]])

这应该会导致一个 4D 数组。

arr4d[0, :, :, :] 应包含原始 3D 数组的第一个窗格的新拆分 3D 数组 (np.array(np.hsplit(arr3d[0, :, :], 3)))

最终结果应该是这样的:

result = np.array(
    [
        [[[0, 1], [6, 7]], [[2, 3], [8, 9]], [[4, 5], [10, 11]]],
        [[[12, 13], [18, 19]], [[14, 15], [20, 21]], [[16, 17], [22, 23]]],
        [[[24, 25], [30, 31]], [[26, 27], [32, 33]], [[28, 29], [34, 35]]],
    ]
)

result.shape
(3, 3, 2, 2)

array([[[[ 0,  1],
         [ 6,  7]],

        [[ 2,  3],
         [ 8,  9]],

        [[ 4,  5],
         [10, 11]]],


       [[[12, 13],
         [18, 19]],

        [[14, 15],
         [20, 21]],

        [[16, 17],
         [22, 23]]],


       [[[24, 25],
         [30, 31]],

        [[26, 27],
         [32, 33]],

        [[28, 29],
         [34, 35]]]])

我正在寻找一种 pythonic 方式来执行此 reshaping/splitting。

尝试:

sh = arr3d.shape[:-1] + (3, -1)
arr4d = arr3d.reshape(*sh).swapaxes(1, 2)

>>> arr4d
array([[[[ 0,  1],
         [ 6,  7]],

        [[ 2,  3],
         [ 8,  9]],

        [[ 4,  5],
         [10, 11]]],


       [[[12, 13],
         [18, 19]],

        [[14, 15],
         [20, 21]],

        [[16, 17],
         [22, 23]]],


       [[[24, 25],
         [30, 31]],

        [[26, 27],
         [32, 33]],

        [[28, 29],
         [34, 35]]]])

说明

这是您要拆分为 (3, -1) 的最后一个维度(在您的示例中为大小 6)。这就是我们首先重塑为 (a, b, 3, -1) 的原因(其中 (a, b, _)arr3d 的形状)。但是因为你对每一行做了 hsplit(),那么你想要的实际形状是 (a, 3, b, -1),所以我们需要交换轴 1 和轴 2(更准确地说:滚动它们,我们将在下面看到更高的维度)。

另一个例子

shape = 7, 2, 3*3
arr3d = np.arange(np.prod(shape)).reshape(*shape)
check = np.array([np.array(np.hsplit(arr3d[k], 3)) for k in range(shape[0)])

sh = arr3d.shape[:-1] + (3, -1)
arr4d = arr3d.reshape(*sh).swapaxes(1, 2)
>>> np.equal(arr4d, check).all()
True

泛化到更高维度

shape = 4, 5, 2, 3*3
ar = np.arange(np.prod(shape)).reshape(*shape)
check = np.array([np.array(np.split(ar[k], 3, axis=-1)) for k in range(shape[0])])

# any dimension
sh = ar.shape[:-1] + (3, -1)
out = np.rollaxis(ar.reshape(*sh), -2, 1)
>>> np.equal(out, check).all()
True