Speed up A* implementation in Python

The A* algorithm is a pathfinding algorithm similar to Dijkstra's algorithm. It works by visiting nodes (using a heuristic to decide which node to visit next) and comparing each node against the nodes it has already visited, the closed list.

In my implementation, the number of nodes visited per second drops dramatically as the closed list grows. Initially the algorithm visits roughly 3,000 nodes/second, but once the closed list grows beyond 10,000 nodes, that number drops to fewer than 50 nodes/second. The only operations that become computationally more expensive are comparing new nodes against the open and closed lists, and storing new nodes in the closed list.

So I think I could significantly improve performance by storing the closed list in a more efficient way!
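
To illustrate why I suspect the list handling, here is a quick comparison (not part of my solver, just a demonstration using timeit) showing that an "in" check on a plain list scales with its size, while the same check on a set stays effectively constant:

    import timeit

    # 10,000 stored positions, roughly the size at which my solver slows down
    positions = [(x, x) for x in range(10_000)]
    as_list = list(positions)
    as_set = set(positions)

    probe = (9_999, 9_999)  # worst case for the list: the last element
    print(timeit.timeit(lambda: probe in as_list, number=1_000))  # linear scan
    print(timeit.timeit(lambda: probe in as_set, number=1_000))   # hash lookup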

Below are some excerpts from my implementation. First the Node class, which is used to define all Nodes:

class Node:
    """
    A node class for A* Pathfinding
    """

    def __init__(self, parent=None, position=None):
        self.parent = parent
        self.position = position

        self.g = 0 # g = actual cost of reaching this node
        self.h = 0 # h = heuristic, used for determining which node to visit next
        self.f = 0 # f = g + h

    def __eq__(self, other):
        return self.position == other.position

    # defining less than for purposes of heap queue
    def __lt__(self, other):
        return self.f < other.f

    # defining greater than for purposes of heap queue
    def __gt__(self, other):
        return self.f > other.f

I use a heap queue to store the open list because I thought that would improve speed. However, it barely did (±5%).

Below is my A* implementation, condensed to only contain the relevant operations:

def find_a_star_path(self, current_pos, target_pos):
    # Initialize start- and end-nodes with zero cost
    start_node = self.Node(None, current_pos)
    start_node.g = start_node.h = start_node.f = 0.0

    end_node = self.Node(None, target_pos)
    end_node.g = end_node.h = end_node.f = 0.0

    # Initialize open- and closed list
    open_list = []
    closed_list = []

    # Heapify the open_list and Add the start node
    heapq.heapify(open_list)
    heapq.heappush(open_list, start_node)

    # As long as there are "open" nodes, we continue A*.
    while len(open_list) > 0:
        # Find node with the lowest cost F, this is visited next
        current_node = heapq.heappop(open_list)
        closed_list.append(current_node)

        if current_node == end_node:
            # if current_node = end_node, the process is finished.

        # Some code that finds all possible next nodes from the next node
        # This next node is called child
        # child.g, child.h and child.f are calculated

        # Now check if the new node is better than another node with the same position but a different parent.

        add_to_open = True
        filtered_open_nodes = (open_node for open_node in open_list if child == open_node)
        open_node = next(filtered_open_nodes, None)

        while open_node:
            if child.f > open_node.f:
                add_to_open = False
                break
            else:
                # The new node is better than the other path to this node, so remove it.
                open_list.remove(open_node)
                open_node = next(filtered_open_nodes, None)

        if add_to_open == True:
            heapq.heappush(open_list, child)

There are a few problems with your code:

  • You never actually use the closed_list; you add nodes to it, but you never check whether current_node is already closed
  • Your closed_list should be a set for O(1) lookups; however, that means you either only add the position, or you also implement Node.__hash__
  • With open_list.remove(open_node) you may invalidate the heap invariant that the algorithms in heapq carefully maintain, which can also lead to longer running times, or worse, to your A* not finding the correct result
  • Not relevant here, but the implementation of Node.__eq__ should be consistent with __lt__ and __gt__ (which compare a different attribute) and __hash__ (not implemented); see the sketch after this list
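
As a minimal sketch (my suggestion, not your original code), here is a Node whose __eq__ and __hash__ are both keyed on position, while the f-based ordering stays separate for the heap:

    class Node:
        """A node class for A* pathfinding, hashable by position."""

        def __init__(self, parent=None, position=None):
            self.parent = parent
            self.position = position
            self.g = 0  # actual cost of reaching this node
            self.h = 0  # heuristic estimate towards the target
            self.f = 0  # f = g + h

        # equality and hashing both use position, so nodes can be
        # stored directly in a set or used as dict keys
        def __eq__(self, other):
            return self.position == other.position

        def __hash__(self):
            return hash(self.position)

        # ordering is only used by the heap queue and compares f
        def __lt__(self, other):
            return self.f < other.f

With a Node like that, you could add the nodes themselves to the closed_set instead of just their positions.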

With that, and a few more cosmetic changes, e.g. using any to check the open_list, your code could look like this:

def find_a_star_path(self, current_pos, target_pos):
    start_node = self.Node(None, current_pos)
    start_node.g = start_node.h = start_node.f = 0.0
    # no need for end node, just compare position

    open_list = [start_node] # no need to heapify list with just one element
    closed_set = set() # should be a set for O(1) "in" check

    while len(open_list) > 0:
        current_node = heapq.heappop(open_list)
        
        # checking the position here, alternatively implement Node.__hash__
        if current_node.position in closed_set:
            continue
        closed_set.add(current_node.position)

        if current_node.position == target_pos:
            # if current_node = end_node, the process is finished.

        for child in [code that finds all possible next nodes]:

            add_to_open = child.position not in closed_set and \
                not any(open_node.f <= child.f for open_node in open_list if open_node == child)

            if add_to_open:
                heapq.heappush(open_list, child)

This might still be slow, though, because the linear scan over the entire open_list in each step by far outweighs the O(log n) heap operations. Usually, you can simply not check whether the candidate node is already in the open_list, i.e. drop the and not any(...) check entirely. This may put a few more nodes on the heap than strictly necessary, but that is probably not a problem at all: once they are popped, they are immediately discarded, because by then they are already in the closed_set. (In fact, you could probably drop the whole add_to_open check, but checking whether they are already in the closed_set is cheap, so why not.)
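
As a minimal sketch (reusing the names and placeholder from the code above), the body of the child loop then shrinks to just the cheap closed_set check:

    for child in [code that finds all possible next nodes]:
        # no scan over open_list: duplicate heap entries are tolerated
        # and discarded when popped, because their position is closed by then
        if child.position not in closed_set:
            heapq.heappush(open_list, child)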

If the extra (possibly duplicate) elements on the heap do cause problems, you could replace the O(n) scan of the open_list with a dict mapping positions (or nodes, if they implement __hash__) to the lowest f value seen so far, giving you O(1) lookups just like with the closed_set:

    closed_set = set()
    open_dict = {start_node.position: start_node.f}

Then:

    for child in [code that finds all possible next nodes]:
        pos = child.position
        add_to_open = pos not in closed_set and \
                (pos not in open_dict or open_dict[pos] > child.f)

        if add_to_open:
            heapq.heappush(open_list, child)
            open_dict[pos] = child.f
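
Note that nothing needs to be removed from open_dict when a node is popped from the heap: a stale entry only costs one superfluous comparison, and duplicates that do get popped are still filtered out by the closed_set check at the top of the loop, just as before.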