非负整数组合的高效枚举

Efficient enumeration of non-negative integer composition

我想写一个函数my_func(n,l),对于一些正整数n,有效地枚举长度为l的有序非负整数组合*(其中l 大于 n)。比如我要my_func(2,3)到return[[0,0,2],[0,2,0],[2,0,0],[1,1,0],[1,0,1],[0,1,1]].

我最初的想法是将现有代码用于正整数分区(例如 accel_asc() 来自 this post),将正整数分区扩展为几个零和 return 所有排列。

def my_func(n, l):
    for ip in accel_asc(n):
        nic = numpy.zeros(l, dtype=int)
        nic[:len(ip)] = ip
        for p in itertools.permutations(nic):
            yield p

这个函数的输出是错误的,因为每个非负整数组合中一个数出现两次(或多次)的数在my_func的输出中出现了多次。例如,list(my_func(2,3)) returns [(1, 1, 0), (1, 0, 1), (1, 1, 0), (1, 0, 1), (0, 1, 1), (0, 1, 1), (2, 0, 0), (2, 0, 0), (0, 2, 0), (0, 0, 2), (0, 2, 0), (0, 0, 2)].

我可以通过生成所有非负整数组合的列表,删除重复的条目,然后 returning 剩余的列表(而不是生成器)来纠正这个问题。但这似乎效率低得令人难以置信,并且可能 运行 进入内存问题。解决此问题的更好方法是什么?

编辑

我快速比较了此 post 和 another post that cglacet has pointed out in the comments.

的答案中提供的解决方案

左边是 l=2*n,右边是 l=n+1。在这两种情况下,当 n<=5 时,user2357112 的第二个解决方案比其他解决方案更快。对于 n>5,user2357112、Nathan Verzemnieks 和 AndyP 提出的解决方案或多或少是相关的。但是当考虑 ln.

之间的其他关系时,结论可能会有所不同

........

*我最初要求的是非负整数partitions。 Joseph Wood 正确地指出我实际上是在寻找整数 组合 ,因为序列中数字的顺序对我很重要。

使用 stars and bars 概念:选择位置在 n 星之间放置 l-1 条,然后计算每个部分中有多少颗星:

import itertools

def diff(seq):
    return [seq[i+1] - seq[i] for i in range(len(seq)-1)]

def generator(n, l):
    for combination in itertools.combinations_with_replacement(range(n+1), l-1):
        yield [combination[0]] + diff(combination) + [n-combination[-1]]

我在这里使用 combinations_with_replacement 而不是 combinations,因此索引处理与您需要的 combinations 有点不同。带有 combinations 的代码将更接近标准的星条处理方式。


或者,另一种使用 combinations_with_replacement 的方法:从 l 个零的列表开始,从 l 个可能的位置中选择 n 个位置进行替换,然后添加1到每个选定的位置以产生输出:

def generator2(n, l):
    for combination in itertools.combinations_with_replacement(range(l), n):
        output = [0]*l
        for i in combination:
            output[i] += 1
        yield output

从一个简单的递归解决方案开始,它与你的问题相同:

def nn_partitions(n, l):
    if n == 0:
        yield [0] * l
    else:
        for part in nn_partitions(n - 1, l):
            for i in range(l):
                new = list(part)
                new[i] += 1
                yield new

也就是说,对于下一个较小数字的每个分区,对于该分区中的每个位置,将该位置的元素加1。它产生与您相同的重复项。不过,我记得一个解决类似问题的技巧:当您将 n 的分区 p 更改为 n+1 的分区时,将 p 的所有元素固定在 p 的左侧你增加的元素。也就是说,跟踪 p 的修改位置,永远不要修改 p 左侧的任何 "descendants"。这是相关代码:

def _nn_partitions(n, l):
    if n == 0:
        yield [0] * l, 0
    else:
        for part, start in _nn_partitions(n - 1, l):
            for i in range(start, l):
                new = list(part)
                new[i] += 1
                yield new, i

def nn_partitions(n, l):
    for part, _ in _nn_partitions(n, l):
        yield part

它非常相似 - 每一步都传递了额外的参数,所以我添加了包装器来为调用者删除它。

我没有对其进行广泛测试,但这似乎相当快 - nn_partitions(3, 5) 大约 35 微秒,nn_partitions(10, 20) 大约 18 秒(产生超过 2000 万个分区)。 (来自 的非常优雅的解决方案对于较小的情况大约需要两倍的时间,对于较大的情况大约需要四倍的时间。编辑:这是指该答案中的第一个解决方案;第二个比我的快在某些情况下,在其他情况下速度较慢。)