获取 m 个列表的 r 长度元组组合,任何列表中不超过一个元素,并且 r < m

Get r-length tuple combinations of m lists, with no more than a single element from any list, and r < m

在下面的示例中,我有 m = 3 个列表,我计算了大小 r = 2 的组合。

import itertools

a = ['a1', 'a2', 'a3']
b = ['b1', 'b2', 'b3']
c = ['c1', 'c2', 'c3']

print(list(itertools.combinations(itertools.chain(a, b, c), 2)))

输出:

[('a1', 'a2'), ('a1', 'a3'), ('a1', 'b1'), ('a1', 'b2'), ('a1', 'b3'), ('a1', 'c1'), ('a1', 'c2'), ('a1', 'c3'), ('a2', 'a3'), ('a2', 'b1'), ('a2', 'b2'), ('a2', 'b3'), ('a2', 'c1'), ('a2', 'c2'), ('a2', 'c3'), ('a3', 'b1'), ('a3', 'b2'), ('a3', 'b3'), ('a3', 'c1'), ('a3', 'c2'), ('a3', 'c3'), ('b1', 'b2'), ('b1', 'b3'), ('b1', 'c1'), ('b1', 'c2'), ('b1', 'c3'), ('b2', 'b3'), ('b2', 'c1'), ('b2', 'c2'), ('b2', 'c3'), ('b3', 'c1'), ('b3', 'c2'), ('b3', 'c3'), ('c1', 'c2'), ('c1', 'c3'), ('c2', 'c3')]

问题:

我不想要来自同一个列表的组合。例如,('a1', 'a2')('a1', 'a3') 应该被删除。

这可能是一个绕过的解决方案,但您可能希望修改作为输出获得的列表以仅保留所需的值。

[i for i in a if i[0][0]!=i[1][0]]

a 是您的列表 a = list(itertools.combinations(itertools.chain(a,b,c), 2))

过滤 itertools.combinations

的输出

这不一定是最优雅或最高效的解决方案,但它确实有效。

def is_unique(comb, lookup):
    xs = [lookup[x] for x in comb]
    return len(xs) == len(set(xs))


def combine(args, r):
    lookup = {x: i for i, xs in enumerate(args) for x in xs}
    return (
        comb
        for comb in itertools.combinations(itertools.chain(*args), r) 
        if is_unique(comb, lookup)
    )


>>> list(combine((a, b, c), 2))
[('a1', 'b1'), ('a1', 'b2'), ('a1', 'b3'),      # (a1, b*)
 ('a1', 'c1'), ('a1', 'c2'), ('a1', 'c3'),      # (a1, c*)
 ('a2', 'b1'), ('a2', 'b2'), ('a2', 'b3'),      # (a2, b*)
 ('a2', 'c1'), ('a2', 'c2'), ('a2', 'c3'),      # (a2, c*)
 ('a3', 'b1'), ('a3', 'b2'), ('a3', 'b3'),      # (a3, b*)
 ('a3', 'c1'), ('a3', 'c2'), ('a3', 'c3'),      # (a3, c*)
 ('b1', 'c1'), ('b1', 'c2'), ('b1', 'c3'),      # (b1, c*)
 ('b2', 'c1'), ('b2', 'c2'), ('b2', 'c3'),      # (b2, c*)
 ('b3', 'c1'), ('b3', 'c2'), ('b3', 'c3')]      # (b3, c*)

稍微解释一下 Mateen 的答案,您可以先在外循环中选择 rm choose r 方式,然后在内循环中迭代 r 组的乘积环形。输出的顺序可能与其他方法不同。

lsts = [a, b, c]
def f(lsts, r):
  """
  lsts :: [[a]]
  r :: Integer
  Generate r-tuples with at most one element coming from
  each member of lsts.
  """
  m = len(lsts)
  assert m >= r
  for xs in itertools.combinations(lsts, r):
    for x in itertools.product(*xs):
      yield x