numpy 中按元素操作的问题

Problem with element-wise operations in numpy

这是我的第一个问题,如果我可以做得更好,请告诉我。

我正在尝试在两个数组之间进行逐元素操作,但广播无法像我希望的那样工作。

我有一个形状为 (N,4) 的数组。

square_list = np.array([[1,2,255,255], [255,255,4,4], [255,255,8,8], [255,255,16,16], [255,255,8,4], [255,1,8,8], [1,255,8,8]], dtype='B')

我还有一个形状为 (4,) 的数组。

square = np.array([1, 8, 8, 1], dtype='B')

我能做的是将我的 squaresquare_list 中的每个元素进行比较,并且它按预期被广播为形状 (N,4)。

现在我想将我的 square 在每个可能的轮换中与 square_list 进行比较。我写了一个函数,它 returns 一个形状数组 (4,4),其中包含每个可能的旋转。

square.rotations
    array([[1, 8, 8, 1],
           [1, 1, 8, 8],
           [8, 1, 1, 8],
           [8, 8, 1, 1]], dtype=uint8)

我知道如何使用循环来做到这一点。但是,我更愿意使用 returns 我想要的形状的元素运算符。

我得到的:

rotations & square_list
ValueError: operands could not be broadcast together with shapes (4,4) (6,4)

我想得到什么:

rotations & square_list
array([[[1, 0, 8, 1],
        [1, 8, 0, 0],
        [1, 8, 8, 0],
        [1, 8, 0, 0],
        [1, 8, 8, 0],
        [1, 0, 8, 0]],
       [[1, 0, 8, 8],
        [1, 1, 0, 0],
        [1, 1, 8, 8],
        [1, 1, 0, 0],
        [1, 1, 8, 0],
        [1, 1, 8, 8]],
       [[0, 0, 1, 8],
        [8, 1, 0, 0],
        [8, 1, 0, 8],
        [8, 1, 0, 0],
        [8, 1, 0, 0],
        [8, 1, 0, 8],
        [0, 1, 0, 8]],
       [[0, 0, 1, 1],
        [8, 8, 0, 0],
        [8, 8, 0, 0],
        [8, 8, 0, 0],
        [8, 8, 0, 0],
        [8, 0, 0, 0],
        [0, 8, 0, 0]]], dtype=uint8)

这只是为了形象化我想要的东西,我并不特别关心轴的顺序'。 (4, N, 4) 或 (N, 4, 4) 的形状都很好。 我觉得这可以通过重塑输入数组之一轻松实现,但我想不通。

提前致谢!

为旋转添加额外维度:

square_list & rotations[:,None]

输出:

array([[[1, 0, 8, 1],
        [1, 8, 0, 0],
        [1, 8, 8, 0],
        [1, 8, 0, 0],
        [1, 8, 8, 0],
        [1, 0, 8, 0],
        [1, 8, 8, 0]],

       [[1, 0, 8, 8],
        [1, 1, 0, 0],
        [1, 1, 8, 8],
        [1, 1, 0, 0],
        [1, 1, 8, 0],
        [1, 1, 8, 8],
        [1, 1, 8, 8]],

       [[0, 0, 1, 8],
        [8, 1, 0, 0],
        [8, 1, 0, 8],
        [8, 1, 0, 0],
        [8, 1, 0, 0],
        [8, 1, 0, 8],
        [0, 1, 0, 8]],

       [[0, 0, 1, 1],
        [8, 8, 0, 0],
        [8, 8, 0, 0],
        [8, 8, 0, 0],
        [8, 8, 0, 0],
        [8, 0, 0, 0],
        [0, 8, 0, 0]]], dtype=uint8)