如何访问 N*M*M numpy 数组的下三角

How to access lower triangle of N*M*M numpy array

我有一个形状为 arr.shape = N,M,M.

的 numpy 数组

我想访问每个 M,M 数组的下三角。我尝试使用

arr1 = arr[:,np.tril_indices(M,-1)]
arr1 = arr[:][np.tril_indices(M,-1)]

等等,在第一种情况下内核死机,而在第二种情况下我收到一条错误消息:

   ---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-23-1b36c5b12706> in <module>
----> 1 arr1 = arr[:][np.tril_indices(M,-1)]

IndexError: index 6 is out of bounds for axis 0 with size 6

在哪里

N=6

澄清一下,我想找到每个 M,M 数组(N 个这样的实例)的下三角中的所有元素,并将结果保存在一个新数组中,形状如下:

arr1.shape = (N,(M*(M-1))/2)

编辑:

虽然 np.tril(arr) 有效,但它会产生一个数组

arr1 = np.tril(arr)
arr1.shape

#(N,M,M)

我希望生成的数组具有指定的形状,即我不想要数组的上部

谢谢

import numpy as np 
a = np.random.rand(2, 5, 5) 
#array([[[0.28212197, 0.29827562, 0.05151153, 0.90448236, 0.07521404],
#        [0.38938978, 0.67007919, 0.83561652, 0.5950061 , 0.73563179],
#        [0.77515285, 0.31973392, 0.91861436, 0.87386527, 0.85917542],
#        [0.12588184, 0.09173029, 0.28577701, 0.4884228 , 0.07183555],
#        [0.68656271, 0.19941039, 0.07924489, 0.15046004, 0.91011737]],
#
#       [[0.18662788, 0.45745028, 0.14557573, 0.22425571, 0.14204739],
#        [0.44502694, 0.85773626, 0.78554919, 0.07306402, 0.14608384],
#        [0.70620254, 0.81497515, 0.09397011, 0.32053184, 0.255485  ],
#        [0.50139688, 0.51539848, 0.24719375, 0.80708819, 0.39685176],
#        [0.94052069, 0.53927081, 0.39567362, 0.06065674, 0.53479994]]])

np.tril(a) 
#array([[[0.28212197, 0.        , 0.        , 0.        , 0.        ],
#        [0.38938978, 0.67007919, 0.        , 0.        , 0.        ],
#        [0.77515285, 0.31973392, 0.91861436, 0.        , 0.        ],
#        [0.12588184, 0.09173029, 0.28577701, 0.4884228 , 0.        ],
#        [0.68656271, 0.19941039, 0.07924489, 0.15046004, 0.91011737]],
#
#       [[0.18662788, 0.        , 0.        , 0.        , 0.        ],
#        [0.44502694, 0.85773626, 0.        , 0.        , 0.        ],
#        [0.70620254, 0.81497515, 0.09397011, 0.        , 0.        ],
#        [0.50139688, 0.51539848, 0.24719375, 0.80708819, 0.        ],
#        [0.94052069, 0.53927081, 0.39567362, 0.06065674, 0.53479994]]])

如果你想去除零点并将其展平为(2, 15)数组(注意每个下三角数组中有10个零点)-

a_no_zeros = np.array([el
for mat in a_lower
for row in mat
for el in row
if el > 0
]).reshape(2, 15)

使用 tri... 函数集时,检查源代码会很有用。它们都是 python,并且基于 np.tri

做一个小样本数组-来说明和验证答案:

In [205]: arr = np.arange(18).reshape(2,3,3)  # arange(1,19) might be better
In [206]: arr
Out[206]: 
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])

tril 将上三角值设置为 0。它在这种情况下有效,但没有记录对 3d 数组的应用。

In [207]: np.tril(arr) 
Out[207]: 
array([[[ 0,  0,  0],
        [ 3,  4,  0],
        [ 6,  7,  8]],

       [[ 9,  0,  0],
        [12, 13,  0],
        [15, 16, 17]]])

但是在代码中 if 首先从最后 2 个维度构造一个布尔掩码:

In [208]: mask = np.tri(*arr.shape[-2:], dtype=bool)
In [209]: mask
Out[209]: 
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])

并使用 np.where 将一些值设置为 0。这在 3d 情况下通过广播工作。 maskarr 在最后两个维度上匹配,所以 mask 可以 broadcast 匹配:

In [210]: np.where(mask, arr, 0)
Out[210]: 
array([[[ 0,  0,  0],
        [ 3,  4,  0],
        [ 6,  7,  8]],

       [[ 9,  0,  0],
        [12, 13,  0],
        [15, 16, 17]]])

你的 tril_indices 就是这个掩码的索引:

In [217]: np.nonzero(mask)    # aka np.where
Out[217]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
In [218]: np.tril_indices(3)
Out[218]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))

它们不能直接用于索引arr:

In [220]: arr[np.tril_indices(3)].shape
Traceback (most recent call last):
  File "<ipython-input-220-e26dc1f514cc>", line 1, in <module>
    arr[np.tril_indices(3)].shape
IndexError: index 2 is out of bounds for axis 0 with size 2

In [221]: arr[:,np.tril_indices(3)].shape
Out[221]: (2, 2, 6, 3)

但是解压两个索引数组:

In [222]: I,J = np.tril_indices(3)
In [223]: I,J
Out[223]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
In [224]: arr[:,I,J]
Out[224]: 
array([[ 0,  3,  4,  6,  7,  8],
       [ 9, 12, 13, 15, 16, 17]])

布尔掩码也可以直接使用:

In [226]: arr[:,mask]
Out[226]: 
array([[ 0,  3,  4,  6,  7,  8],
       [ 9, 12, 13, 15, 16, 17]])

基础 np.tri 通过简单地在索引

上做一个外部 >= 来工作
In [231]: m = np.greater_equal.outer(np.arange(3),np.arange(3))
In [232]: m
Out[232]: 
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])
In [234]: np.arange(3)[:,None]>=np.arange(3)
Out[234]: 
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])