Numpy:无法在列表中搜索元组
Numpy: Can't search for tuple in list
我正在尝试在列表中搜索元组
import numpy as np
l = [(0, 0), (1, 1), (3, 4)]
print(np.where(l == (0, 0)))
>>>> (array([], dtype=int64),)
由于某种原因,这是不可能的。这对我来说很奇怪,因为 l[0]
returns (0, 0)
首先,如果我们想使用 numpy,请确保我们在正确的 ndarray 上工作:
#I add some more data for test purposes
l = np.array([(0, 0), (1, 1), (3, 4), (3, 6), (0,0)])
> array([[0, 0],
[1, 1],
[3, 4],
[3, 6],
[0, 0]])
之后,我建议使用 np.nonzero
来找到符合特定条件的所有索引,问题是 numpy 似乎正在解包元组并为我提供每个元组的索引'元素如果元组尊重条件,为了避免这种情况,我们只选择索引的第一个维度并应用 np.unique
来摆脱重复项。
(我怀疑我显然缺少一个更优雅的解决方案)。
最后的结果是:
np.unique(np.array(list(zip(*np.nonzero(l == [0,0]))))[:,0:1])
> array([0, 4], dtype=int64)
请注意,正如评论者所建议的那样,解决方案在列表理解格式中可能更具可读性:
#We avoid using an ndarray since we will be using vanilla python anyway
l = [(0, 0), (1, 1), (3, 4), (3, 6), (0,0)]
l_indx = [i for i, tup in enumerate(l) if tup == (0,0)]
> [0, 4]
TL;DR
为了使用矢量化代码实现您所追求的目标,您需要将 l
(或 (0, 0)
)转换为 NumPy 数组并使用 np.all(..., axis=1)
:
import numpy as np
l = [(0, 0), (1, 1), (3, 4)]
print(np.nonzero(np.all(np.array(l) == (0, 0), axis=1))[0])
# [0]
您可能应该退后一步,尝试理解您的代码的各个部分。
- 您的输入是
list
。列表上的等于运算符 ==
作用于整个列表,根据两个列表是否包含相同的元素,您将得到 True
或 False
。
np.where()
没有更多选项(应避免使用 np.nonzero()
)将查找数组中的 non-zero 条目和 return 这些条目的索引.
因此,标量布尔值上的 np.where()
将 return (array([0]),)
或 (array([], dtype=int64),)
分别对应 True
或 False
。
如果您要在 NumPy 数组上使用 ==
,结果仍然会与您预期的不同,主要是因为 ==
在这种情况下会起作用 element-wise(最终广播论据)。例如:
a = np.array([[0, 0], [1, 1], [2, 2]])
b = np.array([0, 1])
a == b
# [[ True False]
# [False True]
# [False False]]
要使用矢量化代码获得您想要的内容,您需要找到值为 所有 True
的行,实际上 collapsing/reducing 一个 dimension/axis 与 np.all(..., axis=...)
,例如:
a = np.array([[0, 0], [1, 1], [2, 2]])
c = np.array([0, 0])
np.all(a == c, axis=1)
# [ True, False, False]
这基本上会检查轴 1 中的所有元素是否为 True
。
现在,您可以使用 np.nonzero()
查找索引:
np.nonzero(np.all(a == c, axis=1))[0]
# [0]
如果您不需要使用 NumPy,更简单的方法是循环遍历列表。
假设输入为:
l = [(0, 0), (1, 1), (3, 4)] * 2
x = (0, 0)
一个这样的实现可以是:
def search_all(items, x):
for i, item in enumerate(items):
if item == x:
yield i
print(list(search_all(l, x))
# [0, 3]
如果列表足够稀疏,使用 list.index()
方法循环遍历列表可能会更快,这会强制某些迭代通过更快的路径。
这基本上是在 FlyingCircus index_all()
.
中实现的
def index_all(seq, item):
i = 0
try:
while True:
i = seq.index(item, i)
yield i
i += 1
except ValueError:
return
print(list(index_all(l, x))
# [0, 3]
以上两种方法都非常节省内存,而且速度相当快。
为了更快地实现,可以使用 Cython 来加速显式循环。
正如@KellyBundy 所建议的,甚至可以隐式地执行所有循环:
import itertools
import operator
def search_all_iter(seq, item):
return itertools.compress(
itertools.count(),
map(operator.eq, seq, itertools.repeat(item)))
print(list(search_all_iter(l, x))
# [0, 3]
如果输入是 NumPy 数组(因此不需要转换为 NumPy 数组),我们可以从 al
/ax
:
开始
import numpy as np
al = np.array(l)
ax = np.array(x)
然后可以设计矢量化方法(但是需要相当大的临时对象),例如 find_index_np()
(借鉴自 ) or find_index_unique()
(borrowed from ):
def find_index_np(arr, subarr):
return np.nonzero(np.all(arr == subarr, axis=1))[0]
print(find_index_np(al, ax))
# [0 3]
def find_index_unique(arr, subarr):
return np.unique(np.array(list(zip(*np.nonzero(arr == subarr))))[:, :1])
print(find_index_unique(al, ax))
# [0 3]
或者,可以使用 Numba 加速编写与 NumPy 数组配合良好的内存高效且快速的方法:
@nb.njit
def all_eq(a, b):
for x, y in zip(a, b):
if x != y:
return False
return True
@nb.njit
def find_index_nb(arr, subarr):
result = np.empty(arr.size, dtype=np.int_)
j = 0
for i, x in enumerate(arr):
if all_eq(x, subarr):
result[j] = i
j += 1
return result[:j].copy()
print(find_index_np(al, ax))
# [0 3]
最后,请为以上找到一些基准:
l = ([(0, 0)] + [(1, 1)] * 9) * 1000
x = (0, 0)
al = np.array(l)
ax = np.array(x)
%timeit list(index_all(l, x))
# 1000 loops, best of 5: 500 µs per loop
%timeit list(search_all(l, x))
# 1000 loops, best of 5: 911 µs per loop
%timeit list(search_all_iter(l, x))
# 1000 loops, best of 5: 701 µs per loop
%timeit find_index_np(al, ax)
# 1000 loops, best of 5: 247 µs per loop
%timeit find_index_unique(al, ax)
# 1000 loops, best of 5: 1.41 ms per loop
%timeit find_index_nb(al, ax)
# 1000 loops, best of 5: 465 µs per loop
请注意,所有这些方法的时间复杂度都是 O(N)
(它们的计算时间随着输入的大小线性增加),但它们的相对速度在很大程度上也取决于找到被搜索项目的频率,例如如果我们将 x
/ax
的出现次数从每十次增加到每两次出现一次,那么 index_all()
就会变得比 search_all()
:
慢
l = ([(0, 0)] + [(1, 1)]) * 5000
x = (0, 0)
al = np.array(l)
ax = np.array(x)
%timeit list(index_all(l, x))
# 1000 loops, best of 5: 1.22 ms per loop
%timeit list(search_all(l, x))
# 1000 loops, best of 5: 1.01 ms per loop
%timeit list(search_all_iter(l, x))
# 1000 loops, best of 5: 666 µs per loop
%timeit find_index_np(al, ax)
# 1000 loops, best of 5: 250 µs per loop
%timeit find_index_unique(al, ax)
# 100 loops, best of 5: 6.91 ms per loop
%timeit find_index_nb(al, ax)
# 1000 loops, best of 5: 483 µs per loop
我正在尝试在列表中搜索元组
import numpy as np
l = [(0, 0), (1, 1), (3, 4)]
print(np.where(l == (0, 0)))
>>>> (array([], dtype=int64),)
由于某种原因,这是不可能的。这对我来说很奇怪,因为 l[0]
returns (0, 0)
首先,如果我们想使用 numpy,请确保我们在正确的 ndarray 上工作:
#I add some more data for test purposes
l = np.array([(0, 0), (1, 1), (3, 4), (3, 6), (0,0)])
> array([[0, 0],
[1, 1],
[3, 4],
[3, 6],
[0, 0]])
之后,我建议使用 np.nonzero
来找到符合特定条件的所有索引,问题是 numpy 似乎正在解包元组并为我提供每个元组的索引'元素如果元组尊重条件,为了避免这种情况,我们只选择索引的第一个维度并应用 np.unique
来摆脱重复项。
(我怀疑我显然缺少一个更优雅的解决方案)。
最后的结果是:
np.unique(np.array(list(zip(*np.nonzero(l == [0,0]))))[:,0:1])
> array([0, 4], dtype=int64)
请注意,正如评论者所建议的那样,解决方案在列表理解格式中可能更具可读性:
#We avoid using an ndarray since we will be using vanilla python anyway
l = [(0, 0), (1, 1), (3, 4), (3, 6), (0,0)]
l_indx = [i for i, tup in enumerate(l) if tup == (0,0)]
> [0, 4]
TL;DR
为了使用矢量化代码实现您所追求的目标,您需要将 l
(或 (0, 0)
)转换为 NumPy 数组并使用 np.all(..., axis=1)
:
import numpy as np
l = [(0, 0), (1, 1), (3, 4)]
print(np.nonzero(np.all(np.array(l) == (0, 0), axis=1))[0])
# [0]
您可能应该退后一步,尝试理解您的代码的各个部分。
- 您的输入是
list
。列表上的等于运算符==
作用于整个列表,根据两个列表是否包含相同的元素,您将得到True
或False
。 np.where()
没有更多选项(应避免使用np.nonzero()
)将查找数组中的 non-zero 条目和 return 这些条目的索引.
因此,标量布尔值上的 np.where()
将 return (array([0]),)
或 (array([], dtype=int64),)
分别对应 True
或 False
。
如果您要在 NumPy 数组上使用 ==
,结果仍然会与您预期的不同,主要是因为 ==
在这种情况下会起作用 element-wise(最终广播论据)。例如:
a = np.array([[0, 0], [1, 1], [2, 2]])
b = np.array([0, 1])
a == b
# [[ True False]
# [False True]
# [False False]]
要使用矢量化代码获得您想要的内容,您需要找到值为 所有 True
的行,实际上 collapsing/reducing 一个 dimension/axis 与 np.all(..., axis=...)
,例如:
a = np.array([[0, 0], [1, 1], [2, 2]])
c = np.array([0, 0])
np.all(a == c, axis=1)
# [ True, False, False]
这基本上会检查轴 1 中的所有元素是否为 True
。
现在,您可以使用 np.nonzero()
查找索引:
np.nonzero(np.all(a == c, axis=1))[0]
# [0]
如果您不需要使用 NumPy,更简单的方法是循环遍历列表。
假设输入为:
l = [(0, 0), (1, 1), (3, 4)] * 2
x = (0, 0)
一个这样的实现可以是:
def search_all(items, x):
for i, item in enumerate(items):
if item == x:
yield i
print(list(search_all(l, x))
# [0, 3]
如果列表足够稀疏,使用 list.index()
方法循环遍历列表可能会更快,这会强制某些迭代通过更快的路径。
这基本上是在 FlyingCircus index_all()
.
def index_all(seq, item):
i = 0
try:
while True:
i = seq.index(item, i)
yield i
i += 1
except ValueError:
return
print(list(index_all(l, x))
# [0, 3]
以上两种方法都非常节省内存,而且速度相当快。 为了更快地实现,可以使用 Cython 来加速显式循环。
正如@KellyBundy 所建议的,甚至可以隐式地执行所有循环:
import itertools
import operator
def search_all_iter(seq, item):
return itertools.compress(
itertools.count(),
map(operator.eq, seq, itertools.repeat(item)))
print(list(search_all_iter(l, x))
# [0, 3]
如果输入是 NumPy 数组(因此不需要转换为 NumPy 数组),我们可以从 al
/ax
:
import numpy as np
al = np.array(l)
ax = np.array(x)
然后可以设计矢量化方法(但是需要相当大的临时对象),例如 find_index_np()
(借鉴自 find_index_unique()
(borrowed from
def find_index_np(arr, subarr):
return np.nonzero(np.all(arr == subarr, axis=1))[0]
print(find_index_np(al, ax))
# [0 3]
def find_index_unique(arr, subarr):
return np.unique(np.array(list(zip(*np.nonzero(arr == subarr))))[:, :1])
print(find_index_unique(al, ax))
# [0 3]
或者,可以使用 Numba 加速编写与 NumPy 数组配合良好的内存高效且快速的方法:
@nb.njit
def all_eq(a, b):
for x, y in zip(a, b):
if x != y:
return False
return True
@nb.njit
def find_index_nb(arr, subarr):
result = np.empty(arr.size, dtype=np.int_)
j = 0
for i, x in enumerate(arr):
if all_eq(x, subarr):
result[j] = i
j += 1
return result[:j].copy()
print(find_index_np(al, ax))
# [0 3]
最后,请为以上找到一些基准:
l = ([(0, 0)] + [(1, 1)] * 9) * 1000
x = (0, 0)
al = np.array(l)
ax = np.array(x)
%timeit list(index_all(l, x))
# 1000 loops, best of 5: 500 µs per loop
%timeit list(search_all(l, x))
# 1000 loops, best of 5: 911 µs per loop
%timeit list(search_all_iter(l, x))
# 1000 loops, best of 5: 701 µs per loop
%timeit find_index_np(al, ax)
# 1000 loops, best of 5: 247 µs per loop
%timeit find_index_unique(al, ax)
# 1000 loops, best of 5: 1.41 ms per loop
%timeit find_index_nb(al, ax)
# 1000 loops, best of 5: 465 µs per loop
请注意,所有这些方法的时间复杂度都是 O(N)
(它们的计算时间随着输入的大小线性增加),但它们的相对速度在很大程度上也取决于找到被搜索项目的频率,例如如果我们将 x
/ax
的出现次数从每十次增加到每两次出现一次,那么 index_all()
就会变得比 search_all()
:
l = ([(0, 0)] + [(1, 1)]) * 5000
x = (0, 0)
al = np.array(l)
ax = np.array(x)
%timeit list(index_all(l, x))
# 1000 loops, best of 5: 1.22 ms per loop
%timeit list(search_all(l, x))
# 1000 loops, best of 5: 1.01 ms per loop
%timeit list(search_all_iter(l, x))
# 1000 loops, best of 5: 666 µs per loop
%timeit find_index_np(al, ax)
# 1000 loops, best of 5: 250 µs per loop
%timeit find_index_unique(al, ax)
# 100 loops, best of 5: 6.91 ms per loop
%timeit find_index_nb(al, ax)
# 1000 loops, best of 5: 483 µs per loop