从没有重复的元组列表中获取具有相同 N 个交集的所有元组组的最快算法

Fastest algorithm to get all the tuple groups that has the same N intersections from a list of tuples without duplicates

我有一个包含 100 个元组的列表。每个元组包含 5 个唯一的整数。我想知道找到具有完全相同的 N = 2 个交叉点的所有组的最快方法。如果一个元组有多对元素与其他元组有 2 个交集,则找到所有这些元素并存储在不同的组中。预期的输出是一个唯一列表的列表([(1,2,3,4,5),(4,5,6,7,8)][(4,5,6,7,8),(1,2,3,4,5)] 相同),其中每个列表都是一个组,其中包含具有相同 N=2 个交集的所有元组。下面是我的代码:

from collections import defaultdict
from random import sample, choice

lst =  [tuple(sample(range(10), 5)) for _ in range(100)]

dct = defaultdict(list)
N = 2
for i in lst:
    for j in lst:
        if len(set(i).intersection(set(j))) == N:
            dct[i].append(j)
key = choice(list(dct))
print([key] + dct[key])
>>> [(4, 5, 2, 3, 7), (4, 6, 2, 5, 0), (9, 4, 2, 1, 8), (7, 6, 5, 2, 0), (2, 4, 0, 7, 8)]

显然,所有后 4 个元组与第一个元组都有 2 个交集,但不一定是相同的 2 个元素。那么我应该如何获得具有相同2个交集的元组呢?

一个明显的解决方案是暴力枚举所有可能的 (x, y) 整数对和相应地具有此 (x, y) 交集的组元组,但是有没有更快的算法来做到这一点?

编辑:[(1, 2, 3, 4, 5), (4, 5, 6, 7, 8), (4, 5, 9, 10, 11)]允许在同一组,但[(1, 2, 3, 4, 5),(4, 5, 6, 7, 8), (4, 5, 6, 10, 11)]不可以,因为(4, 5, 6, 7, 8)(4, 5, 6, 10, 11)有3个交集。在这种情况下,它应该被分成2组[(1, 2, 3, 4, 5), (4, 5, 6, 7, 8)][(1, 2, 3, 4, 5), (4, 5, 6, 10, 11)]。最终结果当然会包含各种大小的组,包括许多只有两个元组的短列表,但这就是我想要的。

一个简单的基于组合的方法就足够了:

from collections import defaultdict
from itertools import combinations

res = defaultdict(set)
for t1, t2 in combinations(tuples, 2):
    overlap = set(t1) & set(t2)
    if len(overlap) == 2:
        cur = res[frozenset(overlap)]
        cur.add(t1)
        cur.add(t2)

结果:

defaultdict(set,
            {frozenset({2, 4}): {(2, 4, 0, 7, 8),
              (4, 5, 2, 2, 4),
              (4, 6, 2, 6, 0),
              (8, 4, 2, 1, 8)},
             frozenset({2, 5}): {(4, 5, 2, 2, 4), (7, 6, 5, 2, 0)}})

我喜欢@acushner 的解决方案看起来多么干净,但我写了一个快得多的解决方案:

def all_n_intersections2(xss, n):
    xss = [frozenset(xs) for xs in xss]
    result = {}
    while xss:
        xsa = xss.pop()
        for xsb in xss:
            ixs = xsa.intersection(xsb)
            if len(ixs) == n:
                if ixs not in result:
                    result[ixs] = [xsa, xsb]
                else:
                    result[ixs].append(xsb)
    return result

如果我让他们互相攻击:

from timeit import timeit
from random import sample

from collections import defaultdict
from itertools import combinations


def all_n_intersections1(xss, n):
    res = defaultdict(set)
    for t1, t2 in combinations(xss, 2):
        overlap = set(t1) & set(t2)
        if len(overlap) == n:
            cur = res[frozenset(overlap)]
            cur.add(t1)
            cur.add(t2)


def all_n_intersections2(xss, n):
    xss = [frozenset(xs) for xs in xss]
    result = {}
    while xss:
        xsa = xss.pop()
        for xsb in xss:
            ixs = xsa.intersection(xsb)
            if len(ixs) == n:
                if ixs not in result:
                    result[ixs] = [xsa, xsb]
                else:
                    result[ixs].append(xsb)
    return result


data = [tuple(sample(range(10), 5)) for _ in range(100)]

print(timeit(lambda: all_n_intersections1(data, 2), number=1000))
print(timeit(lambda: all_n_intersections2(data, 2), number=1000))

结果:

3.4294801999999995
1.4871790999999999

加上一些评论:

def all_n_intersections2(xss, n):
    # using frozensets to be able to use them as dict keys, convert only once
    xss = [frozenset(xs) for xs in xss]
    result = {}
    # keep going until there are no more items left to combine
    while xss:
        # popping to compare against all others remaining, intersect each pair only once
        xsa = xss.pop()
        for xsb in xss:
            # using library intersection, assuming the native implementation is fastest
            ixs = xsa.intersection(xsb)
            if len(ixs) == n:
                if ixs not in result:
                    # not using default dict, initialising with these two
                    result[ixs] = [xsa, xsb]
                else:
                    # otherwise, xsa was already in there, appending xsb
                    result[ixs].append(xsb)
    return result

解决方案的作用:

  • 对于 xsaxsb 来自 xss 的每个组合,它计算交集
  • 如果交集ixs是目标长度nxsaxsb被添加到字典中的列表中,使用ixs作为键
  • 避免重复附加(除非源数据中有重复的元组)