How to efficiently count all concatenations of 2-tuples into longer chains in Python

Suppose we want to build a long (metal) chain made up of smaller links joined together. I know in advance how long the chain should be: n. A link is represented as a 2-tuple: (a, b). Two links can be joined together if and only if they share the same element on the side where they meet.
I am given a list (of length n-1) of lists of links, which represents all the links available at each position of the chain. For example:

links = [
    [
        ('a', 1),
        ('a', 2),
        ('a', 3),
        ('b', 1),
    ],
    [
        (1, 'A'),
        (2, 'A'),
        (2, 'B'),
    ],
    [
        ('A', 'a'),
        ('B', 'a'),
        ('B', 'b'),
    ]
]

In this case the length of the final chain would be n = 4.
Here we could produce these possible chains:

('a', 1, 'A', 'a')
('b', 1, 'A', 'a')
('a', 2, 'A', 'a')
('a', 2, 'B', 'a')
('a', 2, 'B', 'b')

The process is a lot like lining up dominoes in a row, except that I cannot rotate the tiles.
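
To make the joining rule concrete, a tiny helper (hypothetical, not used by any of the code below) could look like:

def chainable(left, right):
    # (a, b) and (c, d) can be joined iff b == c, yielding (a, b, d)
    return left[1] == right[0]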

My task: given such an input list, I need to count all the distinct chains of length n that can be created. The case above is a simplified toy example, but in reality the chain length can be up to 1000, and I can have dozens of different links available at each position. However, I know for certain that for every link available at position i there exists another link at position i-1 that is compatible with it.

I wrote a very naive solution that traverses all the links from start to end and merges them together, growing all possible versions of the final chain along the way:


    # THIS CODE WAS BUGGED WHEN I ORIGINALLY POSTED IT
    # BUT IT IS FIXED NOW

    # initialize chains with the links that could make up
    # the first position, then grow them iteratively
    chains = links[0]
    
    # search for all possible paths:
    # iterate over all positions
    for position in links[1:]:
        
        # temp array to help me grow the chain
        temp = []
            
        # over each chain in the current set of chains
        for chain in chains:

            # over each link in a given position
            for link in position:
                
                # check if the chain and link are chainable
                if chain[-1] == link[0]:
                    
                    # append new link to a pre-existing chain
                    temp.append(chain + tuple([link[1]]))
        
        # overwrite the current list of chains
        chains = temp

This solution works correctly, i.e. I am fairly confident it returns the right results. However, it is very slow, and I need to speed it up, ideally by ~100x. I therefore think I need some smart algorithm that counts all the possibilities rather than brute-force concatenating them as above... Since I only need to count the chains, not enumerate them, perhaps there could be a backtracking procedure that starts from every possible final link and accumulates the possibilities along the way, summing over all final links at the end? I have some vague ideas but cannot really pin them down...

Counting is enough, so let's just count. Even much larger cases take only an instant:

from collections import defaultdict, Counter

def count_chains(links):
    # chains[x] = number of chains built so far that end in x;
    # the defaultdict gives every possible first connector a count of 1
    chains = defaultdict(lambda: 1)
    for position in links:
        temp = Counter()
        for a, b in position:
            # every chain ending in a can be extended by the link (a, b)
            temp[b] += chains[a]
        chains = temp
    return sum(chains.values())
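
On the toy example from the question this returns 5, matching the five chains enumerated there:

print(count_chains(links))  # 5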

It's pretty much like yours, except that chains being a list of chains ending in certain b-values is replaced by a Counter of chains ending in certain b-values: chains[b] tells me how many chains end in b. Counters (and defaultdicts) are dictionaries, so I don't have to search and check for matching connectors, I just look them up.

Your compatibility guarantee means we might do somewhat better by going backwards, since then we would never track dead ends, but I don't think that would help much at all (depends on your data).
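
A minimal sketch of that backwards pass, assuming the same input format (my own illustration, not part of the original answer):

from collections import defaultdict, Counter

def count_chains_backwards(links):
    # chains[a] = number of chain suffixes starting with connector a;
    # the guarantee means no suffix counted here is ever a dead end
    chains = defaultdict(lambda: 1)
    for position in reversed(links):
        temp = Counter()
        for a, b in position:
            # every suffix starting with b can be prepended with (a, b)
            temp[a] += chains[b]
        chains = temp
    return sum(chains.values())

It produces the same totals as the forward version; the only difference is which side of each link is used as the lookup key.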

For example, with links = [[(1, 1), (1, 2), (2, 1), (2, 2)]] * 1000, it takes about 2 milliseconds to compute the number of chains, which is:

21430172143725346418968500981200036211228096234110672148875007767407021022498722449863967576313917162551893458351062936503742905713846280871969155149397149607869135549648461970842149210124742283755908364306092949967163882534797535118331087892154125829142392955373084335320859663305248773674411336138752

Try it online!
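
As a sanity check (my own addition): each of the 1000 positions allows all four transitions between the two connectors, so the total doubles at every step and the result is exactly 2**1001:

assert count_chains([[(1, 1), (1, 2), (2, 1), (2, 2)]] * 1000) == 2 ** 1001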

Here is my solution, using a graph data structure; it is more efficient than your approach, which has cubic time complexity O(n³).

import timeit

class Node:
    def __init__(self, val, next=None):
        self.val = val
        self.next = next if next else []
        self.visited = False
        self.mem = None
    def __repr__(self):
        return f'<{self.val} {self.mem}>'
        
def epsi_sol(links, len_):
    nodes = {}
    start_nodes = set(i[0] for i in links[0]) # {'a', 'b'}

    # constructing the graph
    for i in links:
        for j in i:
            if j[0] not in nodes:
                nodes[j[0]] = Node(j[0])
            if j[1] not in nodes:
                nodes[j[1]] = Node(j[1])

            nodes[j[0]].next.append(nodes[j[1]])

    def find_chain_with_length(node, length, valid_length):
        if length + 1 == valid_length:
            return 1

        # if already visited just return
        if node.visited:
            return 0 
        
        if node.mem is not None:
            return node.mem

        # if this is not a leaf node,
        # we mark it as visited
        node.visited = True
        temp_count = 0
        for each_neighbor in node.next:
            temp_count += find_chain_with_length(each_neighbor, length+1, valid_length)
        # after visiting, mark it unvisited again
        node.visited = False
        node.mem = temp_count
        return temp_count

    solution_count = 0
    for each_start_node in start_nodes:
        solution_count += find_chain_with_length(nodes[each_start_node],0, len_)
    return solution_count
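
On the toy example from the question, this also returns 5 (my own check, assuming the links list defined there):

print(epsi_sol(links, 4))  # 5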

from collections import defaultdict, Counter

def kelly_sol(links, len_):
    chains = defaultdict(lambda: 1)
    for position in links[:len_]:
        temp = Counter()
        for a, b in position:
            temp[b] += chains[a]
        chains = temp
    return sum(chains.values())

def mac_sol(links, len_):
    chains = links[0]
    
    # search for all possible paths:
    # iterate over all positions
    for position in links[1:]:
        
        # temp array to help me grow the chain
        temp = []
            
        # over each chain in the current set of chains
        for chain in chains:

            # over each link in a given position
            for link in position:
                
                # check if the chain and link are chainable
                if chain[-1] == link[0]:
                    
                    # append new link to a pre-existing chain
                    temp.append(chain + tuple([link[1]]))
        
        # overwrite the current list of chains
        chains = temp
    return len(chains)

# tests
for n in range(100, 1000, 100):
    links = [[(f'1_{i}', f'1_{i+1}'), (f'1_{i}', f'2_{i+1}'), (f'2_{i}', f'1_{i+1}'), (f'2_{i}', f'2_{i+1}')] for i in range(n)]
    print('-'*50)
    print(f'kelly_sol({n}) => {timeit.timeit(lambda: kelly_sol(links, n+1), number=2)} seconds')
    print(f'epsi_sol({n}) => {timeit.timeit(lambda: epsi_sol(links, n+1), number=2)} seconds')
    if n < 50:  # never true for this range: mac_sol is far too slow at these sizes
        print(f'mac_sol({n}) => {timeit.timeit(lambda: mac_sol(links, n+1), number=2)} seconds')
    print('-'*50)
--------------------------------------------------
kelly_sol(100) => 0.0013752000000000209 seconds
epsi_sol(100) => 0.0022567999999996147 seconds
--------------------------------------------------
--------------------------------------------------
kelly_sol(200) => 0.0026332999999993945 seconds
epsi_sol(200) => 0.004522500000000207 seconds
--------------------------------------------------
--------------------------------------------------
kelly_sol(300) => 0.003924899999999454 seconds
epsi_sol(300) => 0.006861199999999457 seconds
--------------------------------------------------
--------------------------------------------------
kelly_sol(400) => 0.005278099999999952 seconds
epsi_sol(400) => 0.012999699999999947 seconds
--------------------------------------------------
--------------------------------------------------
kelly_sol(500) => 0.006728900000000593 seconds
epsi_sol(500) => 0.01406989999999908 seconds
--------------------------------------------------
--------------------------------------------------
kelly_sol(600) => 0.00828249999999997 seconds
epsi_sol(600) => 0.015398799999999824 seconds
--------------------------------------------------
--------------------------------------------------
kelly_sol(700) => 0.009703200000000578 seconds
epsi_sol(700) => 0.01597070000000045 seconds
--------------------------------------------------
--------------------------------------------------
kelly_sol(800) => 0.009961999999999804 seconds
epsi_sol(800) => 0.0196051999999991 seconds
--------------------------------------------------
--------------------------------------------------
kelly_sol(900) => 0.014800799999999725 seconds
epsi_sol(900) => 0.02183789999999952 seconds
--------------------------------------------------