np.isin - 测试 Numpy 数组是否包含考虑顺序的给定行
np.isin - testing whether a Numpy array contains a given row considering the order
我正在使用以下行查找 b
的行是否在 a
中
a[np.all(np.isin(a[:, 0:3], b[:, 0:3]), axis=1), 3]
数组沿 axis=1
有更多条目,我只比较前 3 个条目和 return a
的第四个条目 (idx=3)。
我意识到的可能错误是,没有考虑条目的顺序。因此,a
和 b
的以下示例:
a = np.array([[...],
[1, 2, 3, 1000],
[2, 1, 3, 2000],
[...]])
b = np.array([[1, 2, 3]])
会 return [1000, 2000]
而不是只有 [1000]
.
如何考虑行的顺序?
对于小 b
(少于 100 行),试试这个:
a[(a[:, :3] == b[:, None]).all(axis=-1).any(axis=0)]
示例:
a = np.array([[1, 0, 5, 0],
[1, 2, 3, 1000],
[2, 1, 3, 2000],
[0, 0, 1, 1]])
b = np.array([[1, 2, 3], [0, 0, 1]])
>>> a[(a[:, :3] == b[:, None]).all(axis=-1).any(axis=0), 3]
array([1000, 1])
解释:
关键是将 a
的所有行(前 3 列)的相等性测试“分发”到 b
的所有行:
# on the example above
>>> a[:, :3] == b[:, None]
array([[[ True, False, False],
[ True, True, True], # <-- a[1,:3] matches b[0]
[False, False, True],
[False, False, False]],
[[False, True, False],
[False, False, False],
[False, False, False],
[ True, True, True]]]) # <-- a[3, :3] matches b[1]
请注意,这可能很大:形状为 (len(b), len(a), 3)
。
那么第一个 .all(axis=-1)
意味着我们希望所有整行都匹配:
>>> (a[:, :3] == b[:, None]).all(axis=-1)
array([[False, True, False, False],
[False, False, False, True]])
最后一位 .any(axis=0)
表示:“匹配 b
中的任何行”:
>>> (a[:, :3] == b[:, None]).all(axis=-1).any(axis=0)
array([False, True, False, True])
即:“a[2, :3]
匹配 一些 行 b
以及 a[3, :3]
”。
最后,在 a
中使用它作为掩码并取第 3 列。
性能说明
上述技术将 a 的行乘积分配给 b 的行。如果 a
和 b
都有很多行,这可能会很慢并且会使用大量内存。
作为替代方案,您可以在纯 Python 中使用 set
成员资格(不对列进行子集化——这可以由调用者完成):
def py_rows_in(a, b):
z = set(map(tuple, b))
return [row in z for row in map(tuple, a)]
当b
超过50~100行时,那么这个可能会更快,相比上面的np
版本,这里写成一个函数:
def np_rows_in(a, b):
return (a == b[:, None]).all(axis=-1).any(axis=0)
import perfplot
fig, axes = plt.subplots(ncols=2, figsize=(16, 5))
plt.subplots_adjust(wspace=.5)
for ax, alen in zip(axes, [100, 10_000]):
a = np.random.randint(0, 20, (alen, 4))
plt.sca(ax)
ax.set_title(f'a: {a.shape[0]:_} rows')
perfplot.show(
setup=lambda n: np.random.randint(0, 20, (n, 3)),
kernels=[
lambda b: np_rows_in(a[:, :3], b),
lambda b: py_rows_in(a[:, :3], b),
],
labels=['np_rows_in', 'py_rows_in'],
n_range=[2 ** k for k in range(10)],
xlabel='len(b)',
)
plt.show()
我正在使用以下行查找 b
的行是否在 a
a[np.all(np.isin(a[:, 0:3], b[:, 0:3]), axis=1), 3]
数组沿 axis=1
有更多条目,我只比较前 3 个条目和 return a
的第四个条目 (idx=3)。
我意识到的可能错误是,没有考虑条目的顺序。因此,a
和 b
的以下示例:
a = np.array([[...],
[1, 2, 3, 1000],
[2, 1, 3, 2000],
[...]])
b = np.array([[1, 2, 3]])
会 return [1000, 2000]
而不是只有 [1000]
.
如何考虑行的顺序?
对于小 b
(少于 100 行),试试这个:
a[(a[:, :3] == b[:, None]).all(axis=-1).any(axis=0)]
示例:
a = np.array([[1, 0, 5, 0],
[1, 2, 3, 1000],
[2, 1, 3, 2000],
[0, 0, 1, 1]])
b = np.array([[1, 2, 3], [0, 0, 1]])
>>> a[(a[:, :3] == b[:, None]).all(axis=-1).any(axis=0), 3]
array([1000, 1])
解释:
关键是将 a
的所有行(前 3 列)的相等性测试“分发”到 b
的所有行:
# on the example above
>>> a[:, :3] == b[:, None]
array([[[ True, False, False],
[ True, True, True], # <-- a[1,:3] matches b[0]
[False, False, True],
[False, False, False]],
[[False, True, False],
[False, False, False],
[False, False, False],
[ True, True, True]]]) # <-- a[3, :3] matches b[1]
请注意,这可能很大:形状为 (len(b), len(a), 3)
。
那么第一个 .all(axis=-1)
意味着我们希望所有整行都匹配:
>>> (a[:, :3] == b[:, None]).all(axis=-1)
array([[False, True, False, False],
[False, False, False, True]])
最后一位 .any(axis=0)
表示:“匹配 b
中的任何行”:
>>> (a[:, :3] == b[:, None]).all(axis=-1).any(axis=0)
array([False, True, False, True])
即:“a[2, :3]
匹配 一些 行 b
以及 a[3, :3]
”。
最后,在 a
中使用它作为掩码并取第 3 列。
性能说明
上述技术将 a 的行乘积分配给 b 的行。如果 a
和 b
都有很多行,这可能会很慢并且会使用大量内存。
作为替代方案,您可以在纯 Python 中使用 set
成员资格(不对列进行子集化——这可以由调用者完成):
def py_rows_in(a, b):
z = set(map(tuple, b))
return [row in z for row in map(tuple, a)]
当b
超过50~100行时,那么这个可能会更快,相比上面的np
版本,这里写成一个函数:
def np_rows_in(a, b):
return (a == b[:, None]).all(axis=-1).any(axis=0)
import perfplot
fig, axes = plt.subplots(ncols=2, figsize=(16, 5))
plt.subplots_adjust(wspace=.5)
for ax, alen in zip(axes, [100, 10_000]):
a = np.random.randint(0, 20, (alen, 4))
plt.sca(ax)
ax.set_title(f'a: {a.shape[0]:_} rows')
perfplot.show(
setup=lambda n: np.random.randint(0, 20, (n, 3)),
kernels=[
lambda b: np_rows_in(a[:, :3], b),
lambda b: py_rows_in(a[:, :3], b),
],
labels=['np_rows_in', 'py_rows_in'],
n_range=[2 ** k for k in range(10)],
xlabel='len(b)',
)
plt.show()