How to remove non-symmetric pairs in a numpy array?

Given an N×2 numpy array data of integers (we can assume that data has no duplicate rows), I need to keep only the rows i for which some row j satisfies the relation

(data[i,0] == data[j,1]) & (data[i,1] == data[j,0])

For example, given

import numpy as np
data = np.array([[1, 2], 
                 [2, 1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])

I should return

array([[1, 2], # because 2,1 is present
       [2, 1], # because 1,2 is present
       [6, 6]]) # because 6,6 is present

A verbose way to do this is

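def filter_symmetric_pairs(data):
  # Start from an empty result with the same dtype as the input,
  # so the output stays an integer array
  result = np.empty((0, 2), dtype=data.dtype)
  for i in range(len(data)):
    for j in range(len(data)):
      # Keep row i if some row j is its mirror image
      if data[i,0] == data[j,1] and data[i,1] == data[j,0]:
        result = np.vstack([result, data[i,:]])
  return result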

I came up with a more concise version:

def filter_symmetric_pairs(data):
  return data[[row.tolist() in data[:,::-1].tolist() for row in data]]

Can anyone suggest a better numpy idiom?

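Here are a couple of different ways to do this. The first is the "obvious" quadratic solution, which is simple but may cause trouble if the input array is large. The second should work as long as the range of numbers in the input is not very large, and it has the advantage of using a linear amount of memory.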

import numpy as np

# Input data
data = np.array([[1, 2],
                 [2, 1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])

# Method 1 (quadratic memory)
d0, d1 = data[:, 0, np.newaxis], data[:, 1]
# Compare all values in first column to all values in second column
c = d0 == d1
# Find where comparison matches both ways
c &= c.T
# Get matching elements
res = data[c.any(0)]
print(res)
# [[1 2]
#  [2 1]
#  [6 6]]

# Method 2 (linear memory)
# Convert pairs into single values
# (assumes non-negative values, otherwise shift first)
n = data.max() + 1
v = data[:, 0] + (n * data[:, 1])
# Symmetric values
v2 = (n * data[:, 0]) + data[:, 1]
# Find where symmetric is present
m = np.isin(v2, v)
res = data[m]
print(res)
# [[1 2]
#  [2 1]
#  [6 6]]
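
The "shift first" step mentioned for Method 2 could look roughly like the sketch below. The function name filter_symmetric_pairs_encoded is made up for illustration, and the sketch assumes a non-empty integer array whose value range is small enough that n * n fits in a 64-bit integer.

```python
import numpy as np

def filter_symmetric_pairs_encoded(data):
    # Hypothetical helper name; a sketch of Method 2 with a shift for negative values
    data = np.asarray(data)
    # Shift so the smallest value becomes 0 (assumes data is non-empty)
    shifted = data.astype(np.int64) - data.min()
    n = shifted.max() + 1
    # Encode each pair (a, b) as the single value a + n * b
    v = shifted[:, 0] + n * shifted[:, 1]
    # Encode each reversed pair (b, a) with the same scheme
    v2 = n * shifted[:, 0] + shifted[:, 1]
    # Keep the rows whose reversed pair is also present
    return data[np.isin(v2, v)]

example = np.array([[-1, 2], [2, -1], [0, 5]])
result = filter_symmetric_pairs_encoded(example)
print(result)
```

Because the mask indexes the original array, the rows come back in their original order and with their original dtype.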

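You can use argsort to sort both the original array and its column-reversed copy while keeping each row's contents together, then check which rows are equal and use that as a mask to slice data. Note that the mask is computed on the sorted rows but applied to data in its original order, so it gives the expected result for this example but is not guaranteed to for arbitrary inputs.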

import numpy as np

data = np.array([[1, 2],
                 [2, 1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])

data_r = data[:,::-1]
sorter = data.argsort(axis=0)[:,0]
sorter_r = data_r.argsort(axis=0)[:,0]
mask = (data.take(sorter, axis=0) == data_r.take(sorter_r, axis=0)).all(axis=1)

data[mask]

# returns:
array([[1, 2],
       [2, 1],
       [6, 6]])
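
Since the mask above is built from sorted rows, a sort-based variant that maps its result back to the original row order may be safer. The following is a minimal sketch of that idea, not the method from the answer: it swaps in np.unique with axis=0, the function name filter_symmetric_pairs_unique is made up for illustration, and it relies on the question's assumption that data has no duplicate rows.

```python
import numpy as np

def filter_symmetric_pairs_unique(data):
    # Hypothetical helper name; assumes data has no duplicate rows
    data = np.asarray(data)
    # Stack the rows on top of their column-reversed copies
    stacked = np.concatenate([data, data[:, ::-1]])
    # Sort and deduplicate rows, counting how often each distinct row occurs
    _, inverse, counts = np.unique(
        stacked, axis=0, return_inverse=True, return_counts=True)
    # Flatten defensively, since the shape of inverse has varied across NumPy versions
    inverse = inverse.ravel()
    # A row of data is kept when it also occurs among the reversed rows,
    # that is, when its count in the stacked array is at least 2
    mask = counts[inverse[:len(data)]] >= 2
    return data[mask]

data = np.array([[1, 2],
                 [2, 1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])
result = filter_symmetric_pairs_unique(data)
print(result)
# [[1 2]
#  [2 1]
#  [6 6]]
```

An input such as [[1, 2], [3, 4], [2, 1]], where the symmetric rows are not adjacent after sorting by the first column, is a useful check for any sort-based approach.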

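I came up with another solution, which treats data as the edge list of a directed graph and keeps only the bidirectional edges (so my question is equivalent to ):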

def filter_symmetric_pairs(data):
  # Size of the adjacency matrix (assumes non-negative integers)
  rank = data.max() + 1
  adj = np.zeros((rank, rank))
  # Treat the rows as edges of a directed graph and build the adjacency matrix
  adj[data[:,0], data[:,1]] = 1
  # Keep entries that are symmetric and nonzero, i.e. bidirectional edges
  bidirected_edges = (adj == adj.T) & (adj == 1)
  # List the (row, column) indices of those entries, ordered by row index
  return np.vstack(np.nonzero(bidirected_edges)).T
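
If the original row order and dtype should be preserved, the same adjacency-matrix idea can be used to build a mask over data instead of reading the indices back out of the matrix. A minimal sketch, assuming a non-empty array of non-negative integers small enough that a (max + 1) × (max + 1) boolean matrix fits in memory; the function name filter_symmetric_pairs_adj is made up for illustration.

```python
import numpy as np

def filter_symmetric_pairs_adj(data):
    # Hypothetical helper name; assumes non-empty, non-negative integer data
    data = np.asarray(data)
    size = int(data.max()) + 1
    # Boolean adjacency matrix: adj[a, b] is True when the edge (a, b) is in data
    adj = np.zeros((size, size), dtype=bool)
    adj[data[:, 0], data[:, 1]] = True
    # Row (a, b) is kept when the reverse edge (b, a) is also present
    mask = adj[data[:, 1], data[:, 0]]
    return data[mask]

data = np.array([[1, 2],
                 [2, 1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])
result = filter_symmetric_pairs_adj(data)
print(result)
# [[1 2]
#  [2 1]
#  [6 6]]
```

Memory use is quadratic in the largest value rather than in the number of rows, so this only pays off when the values are small.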