numpy 在这种特定情况下是否需要双换位?

numpy is double transposition necessary in this specific case?

我有一个数组

xx = np.arange(24).reshape(2, 12)

array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])

我想重塑它,得到

array([[[ 0,  1,  2,  3],
        [12, 13, 14, 15]],

       [[ 4,  5,  6,  7],
        [16, 17, 18, 19]],

       [[ 8,  9, 10, 11],
        [20, 21, 22, 23]]])

我可以通过

实现
xx.T.reshape(3, 4, 2).transpose(0, 2, 1)

但是要转两次,我觉得没必要。那么有人可以确认这是唯一的方法还是提供更具可读性的解决方案? 谢谢!

我会这样做:首先,生成两个数组(分开显示是为了分解):

xx.reshape(2, -1, 4)
# Output:
# array([[[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]],
# 
#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])

从这里开始,我会沿着第二个维度堆叠,以便按照您的意愿组合它们:

np.stack(xx.reshape(2, -1, 4), axis=1)
# Output:
# array([[[ 0,  1,  2,  3],
#         [12, 13, 14, 15]],
# 
#        [[ 4,  5,  6,  7],
#         [16, 17, 18, 19]],
# 
#        [[ 8,  9, 10, 11],
#         [20, 21, 22, 23]]])

你会避免换位。希望它更具可读性,但最后,这是非常主观的,对吧? '^^

可以进行单个转置:

data = np.arange(24).reshape(2, 12)
data = data.reshape(2, 3, 4).transpose(1, 0, 2)

编辑:

我使用 itertools.permutationsitertools.product 检查了这个:

import itertools
import numpy as np

data = np.arange(24).reshape(2, 12)
desired_data = np.array([[[ 0,  1,  2,  3],
                          [12, 13, 14, 15]],
                         
                         [[ 4,  5,  6,  7],
                          [16, 17, 18, 19]],
                         
                         [[ 8,  9, 10, 11],
                          [20, 21, 22, 23]]])

shapes = [2, 3, 4]
transpose_dims = [0, 1, 2]

shape_permutations = itertools.permutations(shapes)
transpose_permutations = itertools.permutations(transpose_dims)

for shape, transpose in itertools.product(
    list(shape_permutations),
    list(transpose_permutations),
):
    
    new_data = data.reshape(*shape).transpose(*transpose)

    try:
        np.allclose(new_data, desired_data)
    except ValueError as e:
        pass
    else:
        break

print(f"{shape=}, {transpose=}")

shape=(2, 3, 4), transpose=(1, 0, 2)

除了@Paul 的回答之外,删除其中一个转置可以加快速度。时间增益约为 15%: