Numpy:无法在列表中搜索元组

Numpy: Can't search for tuple in list

我正在尝试在列表中搜索元组

import numpy as np
l = [(0, 0), (1, 1), (3, 4)]
print(np.where(l == (0, 0)))
 >>>> (array([], dtype=int64),)

由于某种原因,这是不可能的。这对我来说很奇怪,因为 l[0] returns (0, 0)

首先,如果我们想使用 numpy,请确保我们在正确的 ndarray 上工作:

#I add some more data for test purposes
l = np.array([(0, 0), (1, 1), (3, 4), (3, 6), (0,0)]) 

> array([[0, 0],
         [1, 1],
         [3, 4],
         [3, 6],
         [0, 0]])

之后,我建议使用 np.nonzero 来找到符合特定条件的所有索引,问题是 numpy 似乎正在解包元组并为我提供每个元组的索引'元素如果元组尊重条件,为了避免这种情况,我们只选择索引的第一个维度并应用 np.unique 来摆脱重复项。 (我怀疑我显然缺少一个更优雅的解决方案)。

最后的结果是:

np.unique(np.array(list(zip(*np.nonzero(l == [0,0]))))[:,0:1])

> array([0, 4], dtype=int64)

请注意,正如评论者所建议的那样,解决方案在列表理解格式中可能更具可读性:

#We avoid using an ndarray since we will be using vanilla python anyway
l = [(0, 0), (1, 1), (3, 4), (3, 6), (0,0)]
l_indx = [i for i, tup in enumerate(l) if tup == (0,0)]

> [0, 4]

TL;DR

为了使用矢量化代码实现您所追求的目标,您需要将 l(或 (0, 0))转换为 NumPy 数组并使用 np.all(..., axis=1):

import numpy as np
l = [(0, 0), (1, 1), (3, 4)]
print(np.nonzero(np.all(np.array(l) == (0, 0), axis=1))[0])
# [0]

您可能应该退后一步,尝试理解您的代码的各个部分。

  1. 您的输入是 list。列表上的等于运算符 == 作用于整个列表,根据两个列表是否包含相同的元素,您将得到 TrueFalse
  2. np.where() 没有更多选项(应避免使用 np.nonzero())将查找数组中的 non-zero 条目和 return 这些条目的索引.

因此,标量布尔值上的 np.where() 将 return (array([0]),)(array([], dtype=int64),) 分别对应 TrueFalse

如果您要在 NumPy 数组上使用 ==,结果仍然会与您预期的不同,主要是因为 == 在这种情况下会起作用 element-wise(最终广播论据)。例如:

a = np.array([[0, 0], [1, 1], [2, 2]])
b = np.array([0, 1])
a == b
# [[ True False]
#  [False  True]
#  [False False]]

要使用矢量化代码获得您想要的内容,您需要找到值为 所有 True 的行,实际上 collapsing/reducing 一个 dimension/axis 与 np.all(..., axis=...),例如:

a = np.array([[0, 0], [1, 1], [2, 2]])
c = np.array([0, 0])
np.all(a == c, axis=1)
# [ True, False, False]

这基本上会检查轴 1 中的所有元素是否为 True

现在,您可以使用 np.nonzero() 查找索引:

np.nonzero(np.all(a == c, axis=1))[0]
# [0]

如果您不需要使用 NumPy,更简单的方法是循环遍历列表。

假设输入为:

l = [(0, 0), (1, 1), (3, 4)] * 2
x = (0, 0)

一个这样的实现可以是:

def search_all(items, x):
    for i, item in enumerate(items):
        if item == x:
            yield i


print(list(search_all(l, x))
# [0, 3]

如果列表足够稀疏,使用 list.index() 方法循环遍历列表可能会更快,这会强制某些迭代通过更快的路径。 这基本上是在 FlyingCircus index_all().

中实现的
def index_all(seq, item):
    i = 0
    try:
        while True:
            i = seq.index(item, i)
            yield i
            i += 1
    except ValueError:
        return


print(list(index_all(l, x))
# [0, 3]

以上两种方法都非常节省内存,而且速度相当快。 为了更快地实现,可以使用 Cython 来加速显式循环。

正如@KellyBundy 所建议的,甚至可以隐式地执行所有循环:

import itertools
import operator


def search_all_iter(seq, item):
    return itertools.compress(
        itertools.count(),
        map(operator.eq, seq, itertools.repeat(item)))


print(list(search_all_iter(l, x))
# [0, 3]

如果输入是 NumPy 数组(因此不需要转换为 NumPy 数组),我们可以从 al/ax:

开始
import numpy as np


al = np.array(l)
ax = np.array(x)

然后可以设计矢量化方法(但是需要相当大的临时对象),例如 find_index_np()(借鉴自 ) or find_index_unique() (borrowed from ):

def find_index_np(arr, subarr):
    return np.nonzero(np.all(arr == subarr, axis=1))[0]


print(find_index_np(al, ax))
# [0 3]
def find_index_unique(arr, subarr):
    return np.unique(np.array(list(zip(*np.nonzero(arr == subarr))))[:, :1])


print(find_index_unique(al, ax))
# [0 3]

或者,可以使用 Numba 加速编写与 NumPy 数组配合良好的内存高效且快速的方法:

@nb.njit
def all_eq(a, b):
    for x, y in zip(a, b):
        if x != y:
            return False
    return True


@nb.njit
def find_index_nb(arr, subarr):
    result = np.empty(arr.size, dtype=np.int_)
    j = 0
    for i, x in enumerate(arr):
        if all_eq(x, subarr):
            result[j] = i
            j += 1
    return result[:j].copy()


print(find_index_np(al, ax))
# [0 3]

最后,请为以上找到一些基准:

l = ([(0, 0)] + [(1, 1)] * 9) * 1000
x = (0, 0)
al = np.array(l)
ax = np.array(x)


%timeit list(index_all(l, x))
# 1000 loops, best of 5: 500 µs per loop
%timeit list(search_all(l, x))
# 1000 loops, best of 5: 911 µs per loop
%timeit list(search_all_iter(l, x))
# 1000 loops, best of 5: 701 µs per loop
%timeit find_index_np(al, ax)
# 1000 loops, best of 5: 247 µs per loop
%timeit find_index_unique(al, ax)
# 1000 loops, best of 5: 1.41 ms per loop
%timeit find_index_nb(al, ax)
# 1000 loops, best of 5: 465 µs per loop

请注意,所有这些方法的时间复杂度都是 O(N)(它们的计算时间随着输入的大小线性增加),但它们的相对速度在很大程度上也取决于找到被搜索项目的频率,例如如果我们将 x/ax 的出现次数从每十次增加到每两次出现一次,那么 index_all() 就会变得比 search_all():

l = ([(0, 0)] + [(1, 1)]) * 5000
x = (0, 0)
al = np.array(l)
ax = np.array(x)


%timeit list(index_all(l, x))
# 1000 loops, best of 5: 1.22 ms per loop
%timeit list(search_all(l, x))
# 1000 loops, best of 5: 1.01 ms per loop
%timeit list(search_all_iter(l, x))
# 1000 loops, best of 5: 666 µs per loop
%timeit find_index_np(al, ax)
# 1000 loops, best of 5: 250 µs per loop
%timeit find_index_unique(al, ax)
# 100 loops, best of 5: 6.91 ms per loop
%timeit find_index_nb(al, ax)
# 1000 loops, best of 5: 483 µs per loop