np.where 具有任意数量的条件

np.where with arbitrary number of conditions

问题

这道题:Numpy where function multiple conditions asks how to use np.where with two conditions. This answer suggests to use the & operator between conditions, which works if we have a low number of conditions which can be typed. This answer suggests using the np.logical_and,只能有两个参数。

本帖:也讨论了np.where的多个条件,但条件的数量是事先知道的。

我正在寻找一种方法来评估 np.where 表达式,而无需事先知道条件的数量。


可重现的设置

我有一个二维数组:

arr = \
np.array([[1,2,3,4],
          [4,5,6,7],
          [9,8,7,6],
          [0,1,0,1],
          [9,7,6,5]])

Select 具有例如大于 5 的索引 1 元素、大于 3 的索引 2 元素的行。为此,我这样做:

res = arr[np.where((arr[:,1]>5) & (arr[:,2]>4))]

res 则为:

array([[9, 8, 7, 6],
       [9, 7, 6, 5]])

符合预期。

但是如果我将这些条件列为列表呢?上面的例子是:

cols = [1,2] # arbitrary length list
tholds = [5,4] # arbitrary length list

这两个列表的长度事先未知,但长度相同

如何使用 colstholds 列表获得 res


我试过的

使用ast.literal_eval定义:

filterstring = "&".join([f"(pdist[:,{col}]>{th})" for col, th in zip(cols,tholds)])

计算结果为 (pdist[:,1]>5)&(pdist[:,2]>4),即我们在手动输入条件时在 np.where() 中得到的结果。

但是ast.literal_eval(f"np.where({filterstring})")报错:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-269-1aaff20de82f> in <module>()
----> 1 ast.literal_eval(f"np.where({filterstring})")

3 frames
/usr/lib/python3.7/ast.py in _convert_num(node)
     53         elif isinstance(node, Num):
     54             return node.n
---> 55         raise ValueError('malformed node or string: ' + repr(node))
     56     def _convert_signed_num(node):
     57         if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):

ValueError: malformed node or string: <_ast.Call object at 0x7f41daa21f10>

所以这没有用。 to the question 确认这不是正确的方法。


编辑:

连续使用 np.wheres 的建议对于这个特定的例子很有效,但并不是我真正想要的。我想调用 np.where 一次,而不是多次只评估一个条件。

试图避免评估。它有一些安全隐患。

您可以像这样反复进行

def unknown_conditions(arr, cols, tholds):

    for col, thold in zip(cols, tholds):
        arr = arr[np.where(arr[:, col] > thold)]
    
    return arr

您可以累加满足条件的次数,然后调用一次np.where函数。从那里可以很容易地根据条件混合 and/or 组合。

(概念上与学院的建议非常相似。)

def filter_by_conditions(arr, cols, tholds):
    n_conditions = len(cols)
    bool_accumulator = np.zeros(arr.shape[0])
    for c, t in zip(cols, tholds):
        bool_accumulator += (arr[:, c] > t).astype(int)

    return arr[np.where(bool_accumulator) == n_conditions]

您可以通过将数组列的重新排序视图(这是“使用索引列表”的一种奇特的说法)与广播比较相结合来完成此操作,使用 np.all 减少行数

>>> arr[np.where(np.all(arr[:,cols] > thds, axis=1))]
array([[9, 8, 7, 6],
       [9, 7, 6, 5]])

如您的第一个 link 所示(并且如 documentation for np.where 顶部的注释中所述),在这种情况下实际上不需要 np.where;它只会减慢速度。您可以使用布尔列表对 Numpy 数组进行切片,因此您无需将布尔列表更改为索引列表。由于 np.all& 运算符一样,returns 是布尔值的 Numpy 数组,因此也不需要 np.asarraynp.nonzero(如上述注释):

>>> arr[np.all(arr[:,cols] > thds, axis=1)]
array([[9, 8, 7, 6],
       [9, 7, 6, 5]])