Speed up A* implementation in Python
The A* algorithm is a pathfinding algorithm similar to Dijkstra's algorithm. It works by visiting nodes (using a heuristic to decide which node to visit next) and comparing each node against the nodes it has already visited, which are stored in a closed list.
In my implementation, the number of nodes visited per second drops dramatically as the closed list grows. Initially the algorithm visits about 3,000 nodes/second, but once the closed list grows beyond 10,000 nodes, that drops to fewer than 50 nodes/second. The only things that become computationally more expensive are comparing new nodes against the open and closed lists and storing new nodes in the closed list.
So I think I could significantly improve performance by storing the closed list in a more efficient way!
Below are some excerpts from my implementation. First, the Node class, used to define all Nodes:
class Node:
    """
    A node class for A* Pathfinding
    """

    def __init__(self, parent=None, position=None):
        self.parent = parent
        self.position = position
        self.g = 0  # g = actual cost of reaching this node
        self.h = 0  # h = heuristic, used for determining which node to visit next
        self.f = 0  # f = g + h

    def __eq__(self, other):
        return self.position == other.position

    # defining less than for purposes of heap queue
    def __lt__(self, other):
        return self.f < other.f

    # defining greater than for purposes of heap queue
    def __gt__(self, other):
        return self.f > other.f
I use a heap queue to store the open list because I expected that to improve speed. However, it barely did (±5%).
Below is my A* implementation, condensed to only the relevant operations:
def find_a_star_path(self, current_pos, target_pos):
    # Initialize start- and end-nodes with zero cost
    start_node = self.Node(None, current_pos)
    start_node.g = start_node.h = start_node.f = 0.0
    end_node = self.Node(None, target_pos)
    end_node.g = end_node.h = end_node.f = 0.0

    # Initialize open- and closed list
    open_list = []
    closed_list = []

    # Heapify the open_list and Add the start node
    heapq.heapify(open_list)
    heapq.heappush(open_list, start_node)

    # As long as there are "open" nodes, we continue A*.
    while len(open_list) > 0:
        # Find node with the lowest cost F, this is visited next
        current_node = heapq.heappop(open_list)
        closed_list.append(current_node)

        if current_node == end_node:
            # if current_node = end_node, the process is finished.

        # Some code that finds all possible next nodes from the next node
        # This next node is called child
        # child.g, child.h and child.f are calculated

        # Now check if the new node is better than another node with the
        # same position but a different parent.
        filtered_open_nodes = (open_node for open_node in open_list if child == open_node)
        open_node = next(filtered_open_nodes, None)
        while open_node:
            if child.f > open_node.f:
                add_to_open = False
                break
            else:
                # The new node is better than the other path to this node, so remove it.
                open_list.remove(open_node)
            open_node = next(filtered_open_nodes, None)

        if add_to_open == True:
            heapq.heappush(open_list, child)
There are a few problems with your code:
- You never actually use closed_list; you add nodes to it, but you never check whether current_node is already closed
- Your closed_list should be a set for O(1) lookups; however, this means you either add only the position, or you also implement Node.__hash__ (see the sketch after this list)
- With open_list.remove(open_node) you probably invalidate the heap invariant that the algorithms in heapq carefully maintain, which can lead to longer runtimes or, worse, to your A* not finding the correct result
- Not relevant here, but the implementation of Node.__eq__ should be made consistent with __lt__ and __gt__ (which compare by a different attribute) and with __hash__ (not implemented)
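For the second and fourth points, here is a minimal sketch of how Node could be made hashable consistently with its __eq__, so that nodes can be stored directly in a set. It assumes position is already a hashable value such as an (x, y) tuple; the f-based ordering stays separate and is only used by the heap queue:

class Node:
    """
    A node class for A* Pathfinding (sketch with a __hash__ consistent with __eq__)
    """

    def __init__(self, parent=None, position=None):
        self.parent = parent
        self.position = position  # assumed hashable, e.g. a tuple (x, y)
        self.g = 0
        self.h = 0
        self.f = 0

    def __eq__(self, other):
        # a node's identity is its position ...
        return self.position == other.position

    def __hash__(self):
        # ... so hashing must use the same attribute, which allows nodes in a set
        return hash(self.position)

    # ordering by f is only used by the heap queue
    def __lt__(self, other):
        return self.f < other.f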
With that, plus a few cosmetic changes such as using any to check open_list, your code could look like this:
def find_a_star_path(self, current_pos, target_pos):
    start_node = self.Node(None, current_pos)
    start_node.g = start_node.h = start_node.f = 0.0
    # no need for end node, just compare position

    open_list = [start_node]  # no need to heapify list with just one element
    closed_set = set()  # should be a set for O(1) "in" check

    while len(open_list) > 0:
        current_node = heapq.heappop(open_list)

        # checking the position here, alternatively implement Node.__hash__
        if current_node.position in closed_set:
            continue
        closed_set.add(current_node.position)

        if current_node.position == target_pos:
            # if current_node == end_node, the process is finished.

        for child in [code that finds all possible next nodes]:
            add_to_open = child.position not in closed_set and \
                          not any(open_node.f <= child.f for open_node in open_list if open_node == child)
            if add_to_open:
                heapq.heappush(open_list, child)
But this may still be slow, since the linear scan over the entire open_list at every step far outweighs the O(log n) heap operations. Often, you can simply not check whether a candidate node is already in open_list at all, i.e. drop the and not any(...) check. This can put a few more nodes on the heap than strictly necessary, but that may not be a problem at all: as soon as they are popped, they are discarded, because by then they are already in closed_set. (In fact, you could probably drop the entire add_to_open check, but checking whether they are already in closed_set is cheap, so why not.)
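A minimal sketch of that simplification (the neighbour-generating code is still elided, just as in the snippets above):

for child in [code that finds all possible next nodes]:
    # only skip children whose position is already finalised; possible
    # duplicates on the heap are harmless, they get discarded when popped
    if child.position not in closed_set:
        heapq.heappush(open_list, child)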
If the extra (possibly duplicate) elements on the heap do cause problems, you can replace the O(n) scan of open_list with a dict that maps positions (or nodes, if they implement __hash__) to the minimum f value seen so far, giving you O(1) lookups just like closed_set:
closed_set = set()
open_dict = {start_node.position: start_node.f}
and then:
for child in [code that finds all possible next nodes]:
    pos = child.position
    add_to_open = pos not in closed_set and \
                  (pos not in open_dict or open_dict[pos] > child.f)
    if add_to_open:
        heapq.heappush(open_list, child)
        open_dict[pos] = child.f
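One detail this sketch leaves open: open_dict only ever grows. If that memory matters, a possible refinement (not required for correctness) is to drop a position's entry once that position has been finalised, right where the popped node is added to closed_set:

current_node = heapq.heappop(open_list)
if current_node.position in closed_set:
    continue
closed_set.add(current_node.position)
# the position is now finalised, so its "best f seen on the open list" entry
# is no longer needed (hypothetical clean-up, assuming open_dict is
# maintained as in the loop above)
open_dict.pop(current_node.position, None)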