How to remove non-symmetric pairs in a numpy array?
Given an Nx2 numpy array data of integers (we can assume data has no duplicate rows), I need to keep only the rows i for which some row j satisfies the relation
(data[i,0] == data[j,1]) & (data[i,1] == data[j,0])
For example, given
import numpy as np
data = np.array([[1, 2],
                 [2, 1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])
I should get back
array([[1, 2],   # because 2,1 is present
       [2, 1],   # because 1,2 is present
       [6, 6]])  # because 6,6 is present
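A verbose way to do it is:
def filter_symmetric_pairs(data):
    result = np.empty((0, 2), dtype=data.dtype)  # keep the integer dtype of the input
    for i in range(len(data)):
        for j in range(len(data)):
            # row i is kept if some row j is its reversal
            if data[i, 0] == data[j, 1] and data[i, 1] == data[j, 0]:
                result = np.vstack([result, data[i, :]])
    return result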
I came up with a more concise one:
def filter_symmetric_pairs(data):
    return data[[row.tolist() in data[:, ::-1].tolist() for row in data]]
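The concise version rebuilds the list of reversed rows for every row and scans it linearly, so it is quadratic overall. A minimal sketch of the same membership idea with the reversed rows hashed once into a set, assuming data is a 2-D integer array (the function name filter_symmetric_pairs_set is hypothetical):

```python
import numpy as np

def filter_symmetric_pairs_set(data):
    """Keep rows (a, b) whose reversal (b, a) is also a row of data."""
    data = np.asarray(data)
    # Build the set of reversed rows once; each lookup is then O(1) on average
    reversed_rows = set(map(tuple, data[:, ::-1].tolist()))
    # A list of Python bools is used as a boolean mask over the rows
    return data[[tuple(row) in reversed_rows for row in data.tolist()]]
```

This still loops in Python, but each row costs a hash lookup instead of a scan over all N rows.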
Can anyone recommend a better numpy idiom?
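There are a couple of different ways you can do this. The first is the "obvious" quadratic solution, which is simple but can give you trouble if the input array is large, because it builds an N×N comparison matrix. The second should work as long as the range of the numbers in the input is not very large, and it has the advantage of using a linear amount of memory.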
import numpy as np
# Input data
data = np.array([[1, 2],
[2, 1],
[7, 3],
[6, 6],
[5, 6]])
# Method 1 (quadratic memory)
d0, d1 = data[:, 0, np.newaxis], data[:, 1]
# Compare all values in first column to all values in second column
c = d0 == d1
# Find where comparison matches both ways
c &= c.T
# Get matching elements
res = data[c.any(0)]
print(res)
# [[1 2]
# [2 1]
# [6 6]]
# Method 2 (linear memory)
# Convert pairs into single values
# (assumes non-negative values, otherwise shift first)
n = data.max() + 1
v = data[:, 0] + (n * data[:, 1])
# Symmetric values
v2 = (n * data[:, 0]) + data[:, 1]
# Find where symmetric is present
m = np.isin(v2, v)
res = data[m]
print(res)
# [[1 2]
# [2 1]
# [6 6]]
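The answer names two caveats: Method 1 builds an N×N boolean matrix, and Method 2 needs non-negative values. A sketch addressing both, under stated assumptions. The first function (hypothetical name and hypothetical chunk parameter) compares one block of rows at a time against all rows, so peak memory is on the order of chunk × N booleans rather than N × N. The second shifts the data so its minimum becomes 0 before packing and does the arithmetic in int64; it assumes a non-empty array and a value range r with (r + 1)² fitting in int64 (r below roughly 3·10⁹).

```python
import numpy as np

def filter_symmetric_pairs_chunked(data, chunk=1024):
    """Method 1, processed in blocks of `chunk` rows to bound memory."""
    data = np.asarray(data)
    keep = np.zeros(len(data), dtype=bool)
    for start in range(0, len(data), chunk):
        block = data[start:start + chunk]
        # match[k, j] is True when data row j is the reversal of block row k
        match = ((block[:, 0, np.newaxis] == data[:, 1]) &
                 (block[:, 1, np.newaxis] == data[:, 0]))
        keep[start:start + chunk] = match.any(axis=1)
    return data[keep]

def filter_symmetric_pairs_packed(data):
    """Method 2 with the shift applied first, so negative values also work."""
    data = np.asarray(data)
    # Shift so the smallest value is 0; compute in int64 to delay overflow
    shifted = data.astype(np.int64) - data.min()
    n = shifted.max() + 1
    v = shifted[:, 0] + n * shifted[:, 1]   # key of each row
    v2 = n * shifted[:, 0] + shifted[:, 1]  # key of each reversed row
    # Index the original array so the result keeps the input dtype
    return data[np.isin(v2, v)]
```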
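You can use argsort on the original array and on the reversed array to sort them while keeping each row's contents together, then check which rows are equal and use the result as a mask to slice data.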
import numpy as np
data = np.array([[1, 2],
[2, 1],
[7, 3],
[6, 6],
[5, 6]])
data_r = data[:,::-1]
sorter = data.argsort(axis=0)[:,0]
sorter_r = data_r.argsort(axis=0)[:,0]
mask = (data.take(sorter, axis=0) == data_r.take(sorter_r, axis=0)).all(axis=1)
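data[mask]
# returns:
# array([[1, 2],
#        [2, 1],
#        [6, 6]])
# Caution: mask compares the two sorted arrays row by row and is then applied to data
# in its original order, so it matches this example but not in general; for
# data = np.array([[1, 5], [2, 3], [3, 2]]) it returns an empty array instead of [[2, 3], [3, 2]].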
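The mask in the argsort answer compares the two sorted arrays position by position and then applies that sorted-order mask to the unsorted data, so it does not test whether a reversed row exists. A sort-based alternative (not the original answer's code) is to stack data with its reversed copy, sort all rows lexicographically with np.lexsort, and flag rows that have an identical neighbour. Since the question guarantees no duplicate rows, a value occurs at most twice in the stacked array (once from data, once from the reversal), so an identical neighbour means the reversed pair is present. The function name is hypothetical.

```python
import numpy as np

def filter_symmetric_pairs_lexsort(data):
    """Keep rows (a, b) whose reversal (b, a) is also a row.

    Assumes data has no duplicate rows, as stated in the question.
    """
    data = np.asarray(data)
    n = len(data)
    # Rows 0..n-1 come from data, rows n..2n-1 are the reversed pairs
    both = np.concatenate([data, data[:, ::-1]])
    # Lexicographic row sort; np.lexsort uses its LAST key as the primary key
    order = np.lexsort((both[:, 1], both[:, 0]))
    s = both[order]
    # True where a sorted row equals the row right after it
    same_as_next = (s[1:] == s[:-1]).all(axis=1)
    dup_sorted = np.zeros(len(both), dtype=bool)
    dup_sorted[:-1] |= same_as_next
    dup_sorted[1:] |= same_as_next
    # Scatter the flags from sorted positions back to original positions
    dup = np.empty(len(both), dtype=bool)
    dup[order] = dup_sorted
    return data[dup[:n]]
```

Unlike the packed-integer method, this makes no assumption about the sign or range of the values, at the cost of an O(N log N) sort of 2N rows.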
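I came up with another solution, which treats data as the edge list of a directed graph and keeps only the bidirected edges (so my problem is equivalent to ):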
def filter_symmetric_pairs(data):
    rank = data.max() + 1  # assumes non-negative integer labels
    adj = np.zeros((rank, rank), dtype=bool)
    adj[data[:, 0], data[:, 1]] = True  # treat each row as an edge of a directed graph and build the adjacency matrix
    bidirected_edges = adj & adj.T  # keep only edges present in both directions
    return np.argwhere(bidirected_edges)  # (row, col) indices of those edges, in row-major order
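The adjacency matrix can also be indexed directly with the reversed coordinates, which gives a boolean mask over the original rows; the result then keeps the input order and dtype, whereas listing the nonzero entries returns pairs sorted by (row, column). A sketch under the same assumptions as the graph solution (non-negative integers with a modest maximum, since the matrix has (max + 1)² entries); the function name is hypothetical.

```python
import numpy as np

def filter_symmetric_pairs_adj(data):
    """Keep rows (a, b) of data for which the edge b -> a also exists."""
    data = np.asarray(data)
    rank = data.max() + 1
    adj = np.zeros((rank, rank), dtype=bool)
    adj[data[:, 0], data[:, 1]] = True  # edge a -> b for every row (a, b)
    # Look up the reverse edge b -> a for each row, in input order
    mask = adj[data[:, 1], data[:, 0]]
    return data[mask]
```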