如何在 python 中实现一个简单的基于贪心多重集的算法

How to implement a simple greedy multiset based algorithm in python

我想实现以下算法。对于 nk,请按排序顺序考虑所有具有重复的组合,其中我们从 {0,..n-1} 中选择具有重复的 k 个数字。例如,如果 n=5k =3 我们有:

[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4), (0, 1, 1), (0, 1, 2), (0, 1, 3), (0, 1, 4), (0, 2, 2), (0, 2, 3), (0, 2, 4), (0, 3, 3), (0, 3, 4), (0, 4, 4), (1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4), (1, 2, 2), (1, 2, 3), (1, 2, 4), (1, 3, 3), (1, 3, 4), (1, 4, 4), (2, 2, 2), (2, 2, 3), (2, 2, 4), (2, 3, 3), (2, 3, 4), (2, 4, 4), (3, 3, 3), (3, 3, 4), (3, 4, 4), (4, 4, 4)]

从现在开始,我会将每个组合都视为一个多重集。我想贪婪地浏览这些多重集并对列表进行分区。分区具有 属性,其中所有多重集的交集大小必须至少为 k-1。所以在这种情况下我们有:

(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4)

然后

 (0, 1, 1), (0, 1, 2), (0, 1, 3), (0, 1, 4)

然后

(0, 2, 2), (0, 2, 3), (0, 2, 4)

然后

(0, 3,  3), (0, 3, 4)

然后

(0, 4, 4)

等等。

在 python 中,您可以按如下方式迭代组合:

import itertools
for multiset in itertools.combinations_with_replacement(range(5),3):
    #Greedy algo

How can I create these partitions?

我遇到的一个问题是如何计算多重集交集的大小。例如,多重集 (2,1,2)(3,2,2) 的交集大小为 2。


这是 n=4, k=4 的完整答案。

(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 0, 2), (0, 0, 0, 3)
(0, 0, 1, 1), (0, 0, 1, 2), (0, 0, 1, 3)
(0, 0, 2, 2), (0, 0, 2, 3)
(0, 0, 3, 3)
(0, 1, 1, 1), (0, 1, 1, 2), (0, 1, 1, 3)
(0, 1, 2, 2), (0, 1, 2, 3)
(0, 1, 3, 3)
(0, 2, 2, 2), (0, 2, 2, 3)
(0, 2, 3, 3), (0, 3, 3, 3)
(1, 1, 1, 1), (1, 1, 1, 2), (1, 1, 1, 3)
(1, 1, 2, 2), (1, 1, 2, 3)
(1, 1, 3, 3)
(1, 2, 2, 2), (1, 2, 2, 3)
(1, 2, 3, 3), (1, 3, 3, 3)
(2, 2, 2, 2), (2, 2, 2, 3)
(2, 2, 3, 3), (2, 3, 3, 3)
(3, 3, 3, 3)

创建分区的一种方法是遍历您的迭代器,然后将每个多重集*与前一个进行比较。我测试了 4 种方法**来比较多重集,我发现最快的是测试成员资格 in 先前多重集的迭代器,一旦成员资格测试失败,它就会被消耗并短路。如果多重集和前一个多重集中的相等项目数等于多重集的长度减 1,则满足对它们进行分组的条件。然后生成 lists 的结果输出生成器,其中 append 项满足前一个 list 的条件,并开始一个新的 list 包含 tuple 否则,yield一次一个地分组以最小化内存使用量:

import itertools

def f(n,k):
    prev, group = None, []
    for multiset in itertools.combinations_with_replacement(range(n),k):
        if prev:
            it = iter(prev)
            for idx, item in enumerate(multiset):
                if item not in it:
                    break
            if idx == len(multiset) - 1:
                group.append(multiset)
                continue
        if group:
            yield group
        group = [multiset]
        prev = multiset
    yield group

测试用例

输入:

for item in f(4,4):
    print(item)

输出:

[(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 0, 2), (0, 0, 0, 3)]
[(0, 0, 1, 1), (0, 0, 1, 2), (0, 0, 1, 3)]
[(0, 0, 2, 2), (0, 0, 2, 3)]
[(0, 0, 3, 3)]
[(0, 1, 1, 1), (0, 1, 1, 2), (0, 1, 1, 3)]
[(0, 1, 2, 2), (0, 1, 2, 3)]
[(0, 1, 3, 3)]
[(0, 2, 2, 2), (0, 2, 2, 3)]
[(0, 2, 3, 3), (0, 3, 3, 3)]
[(1, 1, 1, 1), (1, 1, 1, 2), (1, 1, 1, 3)]
[(1, 1, 2, 2), (1, 1, 2, 3)]
[(1, 1, 3, 3)]
[(1, 2, 2, 2), (1, 2, 2, 3)]
[(1, 2, 3, 3), (1, 3, 3, 3)]
[(2, 2, 2, 2), (2, 2, 2, 3)]
[(2, 2, 3, 3), (2, 3, 3, 3)]
[(3, 3, 3, 3)]

输入:

for item in f(5,3):
    print(item)

输出:

[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4)]
[(0, 1, 1), (0, 1, 2), (0, 1, 3), (0, 1, 4)]
[(0, 2, 2), (0, 2, 3), (0, 2, 4)]
[(0, 3, 3), (0, 3, 4)]
[(0, 4, 4)]
[(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]
[(1, 2, 2), (1, 2, 3), (1, 2, 4)]
[(1, 3, 3), (1, 3, 4)]
[(1, 4, 4)]
[(2, 2, 2), (2, 2, 3), (2, 2, 4)]
[(2, 3, 3), (2, 3, 4)]
[(2, 4, 4)]
[(3, 3, 3), (3, 3, 4)]
[(3, 4, 4), (4, 4, 4)]

* 我称它们为 multisets 以匹配您的术语,但它们实际上是 tuples(有序且不可变的数据结构);使用 collections.Counter 对象,例如 Counter((0, 0, 0, 1)) returns Counter({0: 3, 1: 1}),递减就像一个真正的多集方法,但我发现这比较慢,因为使用顺序实际上是有用。

** 提供与我测试的相同输出的其他较慢的函数:

def f2(n,k):
    prev, group = None, []
    for multiset in itertools.combinations_with_replacement(range(n),k):
        if prev:
            if sum(item1 == item2 for item1, item2 in zip(prev,multiset)) == len(multiset) - 1:
                group.append(multiset)
                continue
        if group:
            yield group
        group = [multiset]
        prev = multiset
    yield group

def f3(n,k):
    prev, group = None, []
    for multiset in itertools.combinations_with_replacement(range(n),k):
        if prev:
            lst = list(prev)
            for item in multiset:
                if item in lst:
                    lst.remove(item)
                else:
                    break
            if len(multiset) - len(lst) == len(multiset) - 1:
                group.append(multiset)
                continue
        if group:
            yield group
        group = [multiset]
        prev = multiset
    yield group

import collections
def f4(n,k):
    prev, group = None, []
    for multiset in itertools.combinations_with_replacement(range(n),k):
        if prev:
            if sum((collections.Counter(prev) - collections.Counter(multiset)).values()) == 1:
                group.append(multiset)
                continue
        if group:
            yield group
        group = [multiset]
        prev = multiset
    yield group

示例时间:

from timeit import timeit
list(f(11,10)) == list(f2(11,10)) == list(f3(11,10)) == list(f4(11,10))
# True
timeit(lambda: list(f(11,10)), number = 10)
# 4.19157001003623
timeit(lambda: list(f2(11,10)), number = 10)
# 7.32002648897469
timeit(lambda: list(f3(11,10)), number = 10)
# 6.236868146806955
timeit(lambda: list(f4(11,10)), number = 10)
# 47.20136355608702

请注意,由于生成的组合数量很大,因此对于较大的 nk 值,所有方法都会变慢。

我们可以查看要分割的 set/list 中的元组,其长度为 k,基数为 n。从数字来看,您的算法在最小数字优先的基础上是贪婪的。设以 k "digits" 和基数 n 的所有数字的集合表示为 N(k,n)。忽略 N(k,n) 不完全是你现在想要划分的列表这一事实,我们可以划分 N(k,n) 根据划分标准,贪婪地以最小的优先为基础;通过从 0 开始计数(例如,在 k=5 的情况下为 00000),并在我们计数时每次出现 carry 时创建一个新分区(即从数字 i 溢出到数字一+1)。 IE。规则是:进位 <=> new_partition.

证明:假设A是进位后的值,进位到第i-th位。 A 与进位前前一个分区中的所有数字共享一个公共前缀,但不包括 i-th,因此至少有 1 个不同. A 仅在 i 后与之前的另一个(较小)数字共享一个后缀,但该数字已经在与其他数字相差更多的分区中比 A 中的 1,所以 A 开始一个新分区。

但是,根据您的说明,我们只考虑N(k,n)的一个子集; X,其中对于 X 中的任何 x,当 i 时 x[i] <= x[j] > j.这给上面的 carry <=> new partition 规则增加了一些复杂性。现在:

  • new_partition => 进位
  • 但是进位并不一定意味着new_partition

只有一种情况进位不意味着new_partition:刚刚有一个进位创建了一个新的进位分区,然后还有另一个进位,由 x[i] <= x[j] when i > j 规则引起。下一个进位不会导致一个以上的变化,因此并不意味着一个新的分区。

实施:


class ExpNum:
  ''' Represents a number with base @base, @size digits, and funny successor semantics. '''
  def __init__(self, base, size):
    if size <= 0 or base <= 1:
      raise Exception("Bad args")
    self.size = size
    self.base = base
    self.number = [0]*size
    self.zero = [0]*size

  def increment(self):
    ''' Increment number by one. If we carry return index of carry else return -1. '''
    carried = -1
    for i in reversed(range(0, len(self.number))):
      self.number[i] = (self.number[i]+1)%self.base
      if self.number[i] != 0:
        break
      carried = i
    if carried >= 0:
      self.pullup()
    return carried

  def pullup(self):
    ''' Ensure x[i] <= x[j] when i > j '''
    for i in range(0, len(self.number)):
      if self.number[i] == 0 and i > 0:
        self.number[i] = self.number[i-1]

  def out_by_one_partition(self):
    ''' Do the partition by counting from 0 to n**k '''
    self.number = [0]*self.size
    just_carried = False
    partition = [list(self.number)]
    carried = self.increment()
    while self.number != self.zero:
      # Check for exception to carry => new partition.
      if carried >= 0 and not (just_carried and list(self.number)[carried] == (self.base -1) and len(partition) == 1):
        yield(partition)
        partition = []
      partition += [list(self.number)]
      just_carried = carried >= 0
      carried = self.increment()
    yield(partition)

测试:

from ExpNum import ExpNum
from timeit import timeit
from pprint import pprint
pprint(list(ExpNum(4,4).out_by_one_partition()))
print(timeit(lambda: list(ExpNum(11,10).out_by_one_partition()), number = 10))

测试结果:

[[[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 3]],
 [[0, 0, 1, 1], [0, 0, 1, 2], [0, 0, 1, 3]],
 [[0, 0, 2, 2], [0, 0, 2, 3]],
 [[0, 0, 3, 3]],
 [[0, 1, 1, 1], [0, 1, 1, 2], [0, 1, 1, 3]],
 [[0, 1, 2, 2], [0, 1, 2, 3]],
 [[0, 1, 3, 3]],
 [[0, 2, 2, 2], [0, 2, 2, 3]],
 [[0, 2, 3, 3], [0, 3, 3, 3]],
 [[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 1, 3]],
 [[1, 1, 2, 2], [1, 1, 2, 3]],
 [[1, 1, 3, 3]],
 [[1, 2, 2, 2], [1, 2, 2, 3]],
 [[1, 2, 3, 3], [1, 3, 3, 3]],
 [[2, 2, 2, 2], [2, 2, 2, 3]],
 [[2, 2, 3, 3], [2, 3, 3, 3]],
 [[3, 3, 3, 3]]]
10.25355386902811