从 n×2 numpy 数组填充 SortedLIst

Populating a SortedLIst from an n×2 numpy array

我有一个形状为 n×2 的 numpy 数组,一串长度为 2 的元组,我想将其传输到 SortedList。所以目标是创建一个带有长度为 2 的整数元组的 SortedList。

问题在于 SortedList 的构造函数检查每个条目的真值。这适用于一维数组:

In [1]: import numpy as np
In [2]: from sortedcontainers import SortedList
In [3]: a = np.array([1,2,3,4])
In [4]: SortedList(a)
Out[4]: SortedList([1, 2, 3, 4], load=1000)

但是对于二维,当每一项都是数组时,没有明确的真值,SortedList不配合:

In [5]: a.resize(2,2)
In [6]: a
Out[6]: 
array([[1, 2],
       [3, 4]])

In [7]: SortedList(a)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-7-7a4b2693bb52> in <module>()
----> 1 SortedList(a)

/home/me/miniconda3/envs/env/lib/python3.6/site-packages/sortedcontainers/sortedlist.py in __init__(self, iterable, load)
     81 
     82         if iterable is not None:
---> 83             self._update(iterable)
     84 
     85     def __new__(cls, iterable=None, key=None, load=1000):

/home/me/miniconda3/envs/env/lib/python3.6/site-packages/sortedcontainers/sortedlist.py in update(self, iterable)
    176         _lists = self._lists
    177         _maxes = self._maxes
--> 178         values = sorted(iterable)
    179 
    180         if _maxes:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我目前的解决方法是手动将每一行转换为一个元组:

sl = SortedList()
for t in np_array:
    x, y = t
    sl.add((x,y))

但是,此解决方案提供了一些改进空间。有没有人知道如何在不显式地将所有数组解包为元组的情况下解决这个问题?

问题不在于正在检查数组的真值,而是正在比较它们以便对它们进行排序。如果你在数组上使用比较运算符,你会得到数组:

>>> import numpy as np
>>> np.array([1, 4]) < np.array([2, 3])
array([ True, False], dtype=bool)

这个生成的布尔数组实际上是sorted检查其真值的数组。

另一方面,元组(或列表)的相同操作将逐个元素进行比较,return 单个布尔值:

>>> (1, 4) < (2, 3)
True
>>> (1, 4) < (1, 3)
False

因此,当 SortedList 尝试对 numpy 数组序列使用 sorted 时,它无法进行比较,因为它需要单个布尔值 return由比较运算符编辑。

对此进行抽象的一种方法是创建一个新数组 class,该数组实现 __eq____lt____gt__ 等比较运算符以重现元组的排序行为。具有讽刺意味的是,最简单的方法是将底层数组转换为元组,例如:

class SortableArray(object):

    def __init__(self, seq):
        self._values = np.array(seq)

    def __eq__(self, other):
        return tuple(self._values) == tuple(other._values)
        # or:
        # return np.all(self._values == other._values)

    def __lt__(self, other):
        return tuple(self._values) < tuple(other._values)

    def __gt__(self, other):
        return tuple(self._values) > tuple(other._values)

    def __le__(self, other):
        return tuple(self._values) <= tuple(other._values)

    def __ge__(self, other):
        return tuple(self._values) >= tuple(other._values)

    def __str__(self):
        return str(self._values)

    def __repr__(self):
        return repr(self._values)

通过此实现,您现在可以对 SortableArray 个对象的列表进行排序:

In [4]: ar1 = SortableArray([1, 3])

In [5]: ar2 = SortableArray([1, 4])

In [6]: ar3 = SortableArray([1, 3])

In [7]: ar4 = SortableArray([4, 5])

In [8]: ar5 = SortableArray([0, 3])

In [9]: lst1 = [ar1, ar2, ar3, ar4, ar5]

In [10]: lst1
Out[10]: [array([1, 3]), array([1, 4]), array([1, 3]), array([4, 5]), array([0, 3])]

In [11]: sorted(lst1)
Out[11]: [array([0, 3]), array([1, 3]), array([1, 3]), array([1, 4]), array([4, 5])]

这对于您的需要来说可能有点过头了,但这是实现它的一种方法。在任何情况下,您都不会在比较时没有 return 单个布尔值的对象序列上使用 sorted

如果您所追求的只是避免 for 循环,您可以将其替换为列表理解(即 SortedList([tuple(row) for row in np_array]))。