我如何计算满足这些约束的序列?

How can I count sequences that meet these constraints?

我正在尝试计算 IO 符号序列的排列,例如人们进入(I 表示“进”)和离开(O 表示“出”)房间。对于给定的 nI 个符号,应该有 O 个符号,给出序列的总长度 2*n 。此外,在有效排列中的任何一点,O 符号的数量必须小于或等于 I 符号的数量(因为有人不可能离开房间时空)。

此外,我有一些IO符号的初始前缀,代表之前进入或离开房间的人。输出应该只计算以该前缀开头的序列。

例如,对于n=1和初始状态'',结果应该是1,因为唯一有效的序列是IO;对于 n=3II 的初始状态,可能的排列是

IIIOOO
IIOIOO
IIOOIO

的结果是 3。 (三个人进出房间有五种方式,但另外两种涉及第一个人立即离开。)

我猜想解决这个问题的最简单方法是使用 itertools.permutations。到目前为止,这是我的代码:

n=int(input())  ##actual length will be 2*n
string=input()
I_COUNT=string.count("I")
O_COUNT=string.count("O")
if string[0]!="I":
 sys.exit()
if O_COUNT>I_COUNT:
 sys.exit()
perms = [''.join(p) for p in permutations(string)]
print(perms)

目标是获取字符串中遗漏的任何内容的排列并将其附加到用户的输入中,那么如何将用户的输入附加到字符串的剩余长度并获取排列计数?

@cache
def count_permutations(ins: int, outs: int):
    # ins and outs are the remaining number of ins and outs to process
    assert outs >= ins
    if ins == 0 :
        # Can do nothing but output "outs"
        return 1
    elif outs == ins:
        # Your next output needs to be an I else you become unbalanced
        return count_permutations(ins - 1, outs)
    else:
        # Your. next output can either be an I or an O
        return count_permutations(ins - 1, outs) + count_permutations(ins, outs - 1)

如果,假设你总共有5个I和5个Os,并且你已经输出了一个I,那么你想要:count_permutations(4, 5).

这是一个动态规划问题。

鉴于剩余的进出操作数,我们执行以下操作之一:

  1. 如果我们进退两难,只能使用其他类型的操作。只有一个可能的分配。

  2. 如果我们有相同数量的ins或outs,我们必须根据问题的约束使用一个in操作。

  3. 最后,如果我们的输入多于输出,我们可以执行任一操作。那么,答案就是我们选择使用 in 操作时的序列数加上我们选择使用 out 操作时的序列数。

这会在 O(n^2) 时间内运行,尽管在实践中可以使用 2D 列表而不是缓存注释使以下代码片段更快(我在这种情况下使用 @cache使重复更容易理解)。

from functools import cache

@cache
def find_permutation_count(in_remaining, out_remaining):
    if in_remaining == 0 or out_remaining == 0:
        return 1
    elif in_remaining == out_remaining:
        return find_permutation_count(in_remaining - 1, out_remaining)
    else:
        return find_permutation_count(in_remaining - 1, out_remaining) + find_permutation_count(in_remaining, out_remaining - 1)
    
print(find_permutation_count(3, 3)) # prints 5

I'm guessing the simplest way to solve this is using itertools.permutations

遗憾的是,这不会很有帮助。问题是 itertools.permutations 不关心它正在排列的元素的值;无论如何,它都将它们视为不同的。所以如果你有 6 个输入元素,并要求长度为 6 的排列,你将得到 720 个结果,即使所有输入都相同。

itertools.combinations 有相反的问题;它不区分 any 元素。当它选择一些元素时,它只会将这些元素按照它们最初出现的顺序排列。所以如果你有 6 个输入元素并要求长度为 6 的组合,你将得到 1 个结果 - 原始序列。

大概你想要做的是生成排列 Is 和 Os 的所有不同方式,然后取出无效的,然后计算剩下的。这是可能的,itertools.

直接使用递归算法会更简单。一般做法如下:

  • 在任何给定时间,我们都关心房间里有多少人,还有多少人必须进入。为了处理这个前缀,我们简单地计算一下现在房间里有多少人,然后从总人数中减去这个数,以确定还有多少人必须进入。我将输入处理留作练习。
  • 为了确定计数,我们计算涉及下一个动作的方式是 I(有人进来),加上涉及下一个动作的方式是 O(有人离开) .
  • 如果每个人都进入了,那么前进的道路只有一个:每个人都必须离开,一次一个。这是一个基本案例。
  • 否则肯定有人进来的可能,我们递归统计其他人进来的方式after;递归调用,房间里多了一个人,少了一个人还得进去。
  • 如果还有人要进,而且现在房间里也有人,那么有人先走也是可以的。之后我们递归计算其他人进入的方式;递归调用,房间里少了一个人,同样的人数还要进。

这相当直接地转化为代码:

def ways_to_enter(currently_in, waiting):
    if waiting == 0:
        return 1
    result = ways_to_enter(currently_in + 1, waiting - 1)
    if currently_in > 0:
        result += ways_to_enter(currently_in - 1, waiting)
    return result

一些测试:

>>> ways_to_enter(0, 1) # n = 1, prefix = ''
1
>>> ways_to_enter(2, 1) # n = 3, prefix = 'II'; OR e.g. n = 4, prefix = 'IIOI'
3
>>> ways_to_enter(0, 3) # n = 3, prefix = ''
5
>>> ways_to_enter(0, 14) # takes less than a second on my machine
2674440

我们可以通过使用 functools.cache (lru_cache prior to 3.9), which will memoize results of the previous recursive calls. The more purpose-built approach is to use dynamic programming 技术修饰函数来提高较大值的性能:在这种情况下,我们将为 ways_to_enter(x, y) 的结果初始化二维存储,并计算这些值一次一个,这样“递归调用”所需的值已经在过程的早期完成。

这种直接方法类似于:

def ways_to_enter(currently_in, waiting):
    # initialize storage
    results = [[0] * currently_in for _ in waiting]
    # We will iterate with `waiting` as the major axis.
    for w, row in enumerate(results):
        for c, column in enumerate(currently_in):
            if w == 0:
                value = 1
            else:
                value = results[w - 1][c + 1]
                if c > 0:
                    value += results[w][c - 1]
            results[w][c] = value
    return results[-1][-1]

长度为 2n 的此类排列的数量由第 nCatalan number 给出。维基百科根据中心二项式系数给出了加泰罗尼亚数的公式:

from math import comb

def count_permutations(n):
  return comb(2*n,n) // (n+1)

for i in range(1,10):
  print(i, count_permutations(i))

# 1 1
# 2 2
# 3 5
# 4 14
# 5 42
# 6 132
# 7 429
# 8 1430
# 9 4862  

itertools 中的 product() 函数将允许您生成给定长度的 'I' 和 'O' 的所有可能序列。

从该列表中,您可以按以用户提供的 start_seq.

开头的序列进行过滤

根据您对 'I' 和 'O' 的数量和顺序的规则,您可以从该列表中筛选出有效的序列:

from itertools import product


def is_valid(seq):
    '''Evaluates a sequence I's and O's following the rules that:
        - there cannot be more outs than ins
        - the ins and outs must be balanced
    '''
    _in, _out = 0, 0
    for x in seq:
        if x == 'I':
            _in += 1
        else:
            _out += 1

        if (_out > _in) or (_in > len(seq)/2):
            return False

    return True


# User inputs...
start_seq = 'II'

assert start_seq[0] != 'O', 'Starting sequence cannot start with an OUT.'

n = 3
total_len = n*2

assert len(start_seq) < total_len, 'Starting sequence is at least as big as total number, nothing to iterate.'

# Calculate all possible sequences that are total_len long, as tuples of 'I' and 'O'
seq_tuples = product('IO', repeat=total_len)

# Convert tuples to strings, e.g., `('I', 'O', 'I')` to `'IOI'`
sequences = [''.join(seq_tpl) for seq_tpl in seq_tuples]

# Filter for sequences that start correctly
sequences = [seq for seq in sequences if seq.startswith(start_seq)]

# Filter for valid sequences
sequences = [seq for seq in sequences if is_valid(seq)]

print(sequences)

我得到:

['IIIOOO', 'IIOIOO', 'IIOOIO']

也许不是很优雅,但这似乎确实满足了要求:

from itertools import permutations

def isvalid(start, p):
    for c1, c2 in zip(start, p):
        if c1 != c2:
            return 0
    n = 0
    for c in p:
        if c == 'O':
            if (n := n - 1) < 0:
                return 0
        else:
            n += 1
    return 1

def calc(n, i):
    s = i + 'I' * (n - i.count('I'))
    s += 'O' * (n * 2 - len(s))
    return sum(isvalid(i, p) for p in set(permutations(s)))

print(calc(3, 'II'))
print(calc(3, 'IO'))
print(calc(3, 'I'))
print(calc(3, ''))

输出:

3
2
5
5
def solve(string,n):
  countI =string.count('I')
  if countI==n:
    return 1
  countO=string.count('O')
  if countO > countI:
    return 0
  k= solve(string + 'O',n)
  h= solve(string + 'I',n)
  return k+h

n= int(input())
string=input()
print(solve(string,n))