列表列表列表之间的交集长度

Length of the intersections between a list an list of list

注意:几乎是

的重复

差异:

x = [500 numbers between 1 and N]
y = [[1, 2, 3], [4, 5, 6, 7], [8, 9], [10, 11, 12], etc. up to N]

这里有一些假设:

我的目的是要知道,在y的子列表中,与x的子列表至少有10个交集。

我显然可以循环 :

def find_best(x, y):
    result = []

    for index, sublist in enumerate(y):
        intersection = set(x).intersection(set(sublist))
        if len(intersection) > 2:  # in real live: > 10
            result.append(index)

    return(result)


x = [1, 2, 3, 4, 5, 6]
y = [[1, 2, 3], [4],  [5, 6], [7], [8, 9, 10, 11]]

res = find_best(x, y)
print(res)   # [0, 2]

这里的结果是 [0,2] 因为 y 的第一个和第三个子列表有 2 个元素与 x.

相交

另一种方法应该只解析一次 y 并计算交叉点:

def find_intersec2(x, y):
    n_sublists = len(y)
    res = {num: 0 for num in range(0, n_sublists + 1)}
    for list_no, sublist in enumerate(y):
        for num in sublist:
            if num in x:
                x.remove(num)
                res[list_no] += 1
    return [n for n in range(n_sublists + 1) if res[n] >= 2]

第二种方法使用了更多的假设。

问题:

由于y包含不相交的范围并且它们的并集也是一个范围,一个非常快速的解决方案是首先对[=14=进行二分查找] 然后计算结果索引,并且只计算 return 出现至少 10 次的索引。该算法的复杂度为 O(Nx log Ny),其中 NxNy 项的数量分别为 xy。该算法接近最佳(因为x需要完整阅读)。


实际实施

首先,您需要将当前的 y 转换为包含所有范围的起始值(按递增顺序)的 Numpy 数组,最后一个值是 N(假设 N 被排除在 y 的范围之外,否则 N+1 )。这部分可以假定为免费的,因为 y 可以在您的情况下在编译时计算。这是一个例子:

import numpy as np
y = np.array([1, 4, 8, 10, 13, ..., N])

然后,您需要执行二进制搜索并检查值是否符合 y:

indices = np.searchsorted(y, x, 'right')

# The `0 < indices < len(y)` check should not be needed regarding the input.
# If so, you can use only `indices -= 1`.
indices = indices[(0 < indices) & (indices < len(y))] - 1

然后你需要计算索引并过滤至少 :

uniqueIndices, counts = np.unique(indices, return_counts=True)
result = uniqueIndices[counts >= 10]

这是一个基于您的示例:

x = np.array([1, 2, 3, 4, 5, 6])

# [[1, 2, 3], [4],  [5, 6], [7], [8, 9, 10, 11]]
y = np.array([1, 4, 5, 7, 8, 12])

# Actual simplified version of the above algorithm
indices = np.searchsorted(y, x, 'right') - 1
uniqueIndices, counts = np.unique(indices, return_counts=True)
result = uniqueIndices[counts >= 2]

# [0, 2]
print(result.tolist())

它在我的机器上运行不到 0.1 毫秒,根据您的输入限制随机输入。

把 y 变成 2 个字典。

index = { # index to count map
    0 : 0,
    1 : 0,
    2 : 0,
    3 : 0,
    4 : 0
}

y = { # elem to index map
    1: 0,
    2: 0,
    3: 0,
    4: 1,
    5: 2,
    6: 2,
    7: 3,
    8 : 4,
    9 : 4,
    10 : 4,
    11 : 4
}

既然你提前知道了y,上面的操作我就不计入时间复杂度了。然后,计算交集:

x = [1, 2, 3, 4, 5, 6]
for e in x: index[y[e]] += 1

既然你提到x很小,我试着让时间复杂度只取决于x的大小(在本例中是O(n))。

最后,答案是索引字典中的键列表,其中值 >= 2(或实际情况下为 10)。

answer = [i for i in index if index[i] >= 2]

这使用 y 创建一个线性数组,将每个 int 映射到(1 加),即 int 所在的范围或子组的索引;称为 x2range_counter.

x2range_counter使用32位的array.array类型来节省内存,可以缓存起来用于计算all x相同 y.

计算每个范围内特定 x 的命中率只是 count'er in function count_ranges` 的间接数组递增。

y = [[1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11, 12]]
x = [5, 3, 1, 11, 8, 10]

range_counter_max = len(y)
extent = y[-1][-1] + 1  # min in y must be 1 not 0 remember.
x2range_counter = array.array('L', [0] * extent)  # efficient 32 bit array storage

# Map any int in any x to appropriate ranges counter.
for range_counter_index, rng in enumerate(y, start=1):
    for n in rng:
        x2range_counter[n] = range_counter_index
print(x2range_counter)  # array('L', [0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4])

# x2range_counter can be saved for this y and any x on this y.

def count_ranges(x: List[int]) -> List[int]:
    "Number of x-hits on each y subgroup in order"
    # Note: count[0] initially catches errors. count[1..] counts x's in y ranges [0..]
    count = array.array('L', [0] * (range_counter_max + 1))
    for xx in x:
        count[x2range_counter[xx]] += 1
    assert count[0] == 0, "x values must all exist in a y range and y must have all int in its range."

    return count[1:] 

print(count_ranges(x))  # array('L', [1, 2, 1, 2])

我为此创建了一个 class,具有额外的功能,例如返回范围而不是索引;所有范围命中 >=M 次; (range, hit-count) 元组最先排序。

不同 x 的范围计算与 x 成正比,并且是简单的数组查找,而不是字典的任何散列。

你怎么看?