如何让我的 A 星搜索算法更有效率?

How do I make my A star search algorithm more efficient?

我在 matplotlib 中有一个网格(20*20 或 40*40,取决于用户的选择),其中包含根据 LatLong 位置划分的数据。该网格中的每个单元格代表一个 0.002 或 0.001 的区域(例如:[-70.55, 43.242][-70.548, 43.244])。网格根据阈值着色,比方说高于 30 的是绿色,低于 30 的是红色。

我实施了 A 开始算法,从该图上的一个位置(单元格)转到另一个位置,同时避开所有绿色单元格。在绿色和红色单元格的边界上行驶的成本为 1.3,而对角线成本为 1.5,在两个红色单元格之间行驶的成本为 1。

我正在使用对角线距离启发式算法,对于每个单元格,我都会获取所有可能的邻居并根据阈值设置它们的 G 值。

现在我大部分时间都能找到正确的路径,对于附近的单元格,它的运行时间不到 1 秒。但是当我走得更远时,需要 14-18 秒。

我不明白我在这里做错了什么?我一直在想办法改进它,但失败了。

这是算法的一个片段。我想指出,确定可访问的邻居并设置 G 值在这里可能不是问题,因为每个函数调用的运行时间约为 0.02 - 0.03 秒。

如有任何建议,我们将不胜感激!谢谢

def a_star(self, start, end):
    startNode = Node(None, start)
    endNode = Node(None, end)
    openList = []
    closedList = []
    openList.append(startNode)

    while len(openList) > 0:

        current = openList[0]
        if current.location == endNode.location:
            path = []
            node = current
            while node is not None:
                path.append(node)
                node = node.parent
            return path[::-1]

        openList.pop(0)
        closedList.append(current)

       # This takes around 0.02 0.03 seconds
        neighbours = self.get_neighbors(current)

        for neighbour in neighbours:
            append = True
            for node in closedList:
                if neighbour.location == node.location:
                    append = False
                    break
            for openNode in openList:
                if neighbour.location == openNode.location:
                    if neighbour.g < openNode.g:
                        append = False
                        openNode.g = neighbour.g
                        break
            if append:

                neighbour.h = round(self.heuristic(neighbour.location, endNode.location), 3)
                neighbour.f = round(neighbour.g + neighbour.h, 3)

                bisect.insort_left(openList, neighbour)
    return False

编辑:添加节点片段

 class Node:
    def __init__(self, parent, location):
        self.parent = parent
        self.location = location
        self.g = 0
        self.h = 0
        self.f = 0

编辑 2:添加图像

圆圈是起点,星号是终点。黄色单元格不可访问,因此黄色单元格上没有对角线路径,并且不能在两个黄色单元格之间移动。

这部分非常效率低下。对于每个邻居,您遍历两个相对较大的列表,一旦列表开始增长,这会使整体复杂度非常高:

for node in closedList:
    if neighbour.location == node.location:
        append = False
        break
for openNode in openList:
    if neighbour.location == openNode.location:

基本上,您的算法不应依赖于任何列表。你有你的细胞,你从列表中弹出一个,你有 8 个邻居,你通过与你拥有的细胞进行比较来处理它们,然后其中一些被附加到列表中。无需循环任何列表。

正如@lenic 已经指出的那样,您拥有的内部循环不属于 A* 算法。

第一个 (for node in closedList) 应该替换为检查节点是否在集合中(而不是在列表中):这样就不需要迭代。

当您初始化所有 g 属性 值为无穷大时(起始节点除外),第二个 (openNode in openList) 是不必要的。然后您可以将新的 g 值与已经存储在 neighbor 节点中的值进行比较。

此外,我建议在创建图形后立即为整个图形创建节点。当您需要对同一个图执行多个查询时,这将很有用。

此外,我建议使用 heapq.heappush 而不是 bisect.insort_left。这会更快,因为它并没有真正对队列进行完全排序,而只是确保堆 属性 得到维护。时间复杂度是一样的。更重要的是,从中得到下一个值的时间复杂度优于openList.pop(0).

我建议使用成本 10、13 和 15 而不是 1、1.3 和 1.5,因为整数运算没有精度问题。

出于同样的原因,我不会使用分数位置坐标。因为它们都相距很远(例如 0.002),所以我们可以只对两个坐标使用顺序整数(0、1、2、...)。一个额外的函数可以采用解决方案和参考坐标对将这些整数坐标转换回 "world" 坐标。

我做了一些假设:

  • 严禁通过"hostile"个单元格。不存在用于交叉这些边的边。例如,如果起始位置在四个敌方单元的中间,则没有路径。人们可以考虑一种替代方案,其中这些边将获得极高的成本,以便您始终能够提出一条路径。

  • 网格边界上的直边都是允许的。当与它相邻的单个细胞是敌对的时,它们的成本将为 1.3 (13)。所以实际上在两个维度中的每一个都有一个比细胞多一个位置

  • 阈值输入将是一个介于 0 和 1 之间的分数,表示相对于单元格总数应该友好的单元格的分数,这个值将被转换为 "split"区分友好和敌对细胞的值。

以下是您可以用作灵感的代码:

from heapq import heappop, heappush

class Node:
    def __init__(self, location):
        self.location = location
        self.neighbors = []
        self.parent = None
        self.g = float('inf')
        self.f = 0

    def clear(self):
        self.parent = None
        self.g = float('inf')
        self.f = 0

    def addneighbor(self, cost, other):
        # add edge in both directions
        self.neighbors.append((cost, other))
        other.neighbors.append((cost, self))

    def __gt__(self, other):  # make nodes comparable
        return self.f > other.f

    def __repr__(self):
        return str(self.location)

class Graph:
    def __init__(self, grid, thresholdfactor):
        # get value that corresponds with thresholdfactor (which should be between 0 and 1)
        values = sorted([value for row in grid for value in row])
        splitvalue = values[int(len(values) * thresholdfactor)]
        print("split at ", splitvalue)
        # simplify grid values to booleans and add extra row/col of dummy cells all around
        width = len(grid[0]) + 1
        height = len(grid) + 1
        colors = ([[False] * (width + 1)] +
            [[False] + [value < splitvalue for value in row] + [False] for row in grid] +
            [[False] * (width + 1)])

        nodes = []
        for i in range(height):
            noderow = []
            nodes.append(noderow)
            for j in range(width):
                node = Node((i, j))
                noderow.append(node)
                cells = [colors[i+1][j]] + colors[i][j:j+2]  # 3 cells around location: SW, NW, NE
                for di, dj in ((1, 0), (0, 0), (0, 1), (0, 2)):  # 4 directions: W, NW, N, NE
                    cost = 0
                    if (di + dj) % 2:  # straight
                        # if both cells are hostile, then not allowed
                        if cells[0] or cells[1]:  # at least one friendly cell
                            # if one is hostile, higher cost
                            cost = 13 if cells[0] != cells[1] else 10
                        cells.pop(0)
                    elif cells[0]:  # diagonal: cell must be friendly
                        cost = 15
                    if cost:
                        node.addneighbor(cost, nodes[i-1+di][j-1+dj])
        self.nodes = nodes

    @staticmethod
    def reconstructpath(node):
        path = []
        while node is not None:
            path.append(node)
            node = node.parent
        path.reverse()
        return path

    @staticmethod
    def heuristic(a, b):
        # optimistic score, assuming all cells are friendly
        dy = abs(a[0] - b[0])
        dx = abs(a[1] - b[1])
        return min(dx, dy) * 15 + abs(dx - dy) * 10

    def clear(self):
        # remove search data from graph 
        for row in self.nodes:
            for node in row:
                node.clear()

    def a_star(self, start, end):
        self.clear()
        startnode = self.nodes[start[0]][start[1]]
        endnode = self.nodes[end[0]][end[1]]
        startnode.g = 0
        openlist = [startnode] 
        closed = set()
        while openlist:
            node = heappop(openlist)
            if node in closed:
                continue
            closed.add(node)
            if node == endnode:
                return self.reconstructpath(endnode)
            for weight, neighbor in node.neighbors:
                g = node.g + weight
                if g < neighbor.g:
                    neighbor.g = g
                    neighbor.f = g + self.heuristic(neighbor.location, endnode.location)
                    neighbor.parent = node
                    heappush(openlist, neighbor)

我将您包含的图形编码为图像,以查看代码的行为方式:

grid = [
    [38, 32, 34, 24,  0, 82,  5, 41, 11, 32,  0, 16,  0,113, 49, 34, 24,  6, 15, 35],
    [61, 61,  8, 35, 65, 31, 53, 25, 66,  0, 21,  0,  9,  0, 31, 75, 20,  8,  3, 29],
    [43, 66, 47,114, 38, 41,  1,108,  9,  0,  0,  0, 39,  0, 27, 72, 19, 14, 24, 25],
    [45,  5, 37, 23,102, 25, 49, 34, 41, 49, 35, 15, 29, 21, 66, 67, 44, 31, 38, 91],
    [47, 94, 48, 69, 33, 95, 18, 75, 28, 70, 38, 78, 48, 88, 21, 66, 44, 70, 75, 23],
    [23, 84, 53, 23, 92, 14, 71, 12,139, 30, 63, 82, 16, 49, 76, 56,119,100, 47, 21],
    [30,  0, 32, 90,  0,195, 85, 65, 18, 57, 47, 61, 40, 32,109,255, 88, 98, 39,  0],
    [ 0,  0,  0,  0, 39, 39, 76,167, 73,140, 58, 56, 94, 61,212,222,141, 50, 41, 20],
    [ 0,  0,  0,  5,  0,  0, 21,  2,132,100,218, 81,  0, 62,135, 42,131, 80, 14, 19],
    [ 0,  0,  0,  0,  0, 15,  9, 55, 70, 71, 42,117, 65, 63, 59, 81,  4, 40, 77, 46],
    [ 0,  0,  0,  0, 55, 52,101, 93, 30,166, 56, 19, 76,103, 54, 37, 24, 23, 59, 98],
    [ 0,  0,  0,  0,  9,  9, 44,149, 11,134, 90, 64, 44, 57, 61, 79,270,201, 84,  6],
    [ 0,  0,  0, 22,  1, 15,  0, 25, 30,101,154, 60, 97, 64, 15,162, 27, 91, 71,  0],
    [ 0,  0,  1, 35,  5, 10,  0, 55, 25,  0,200, 81, 31, 53, 42, 74,127,154,  7,  0],
    [ 0,  0,187, 17, 45, 66, 91,191, 70,189, 18, 25, 67, 32, 40, 79,103, 79, 59,  0],
    [ 0, 21, 16, 14, 19, 58,278, 56,128, 95,  3, 52,  9, 27, 25, 43, 62, 25, 38,  0],
    [ 4,  3, 11, 26,119,165, 53, 85, 46, 81, 19, 11, 12, 19, 18,  9, 16,  6, 37,  0],
    [ 5,  0,  0, 65,158,153,118, 38,123, 46, 28, 24,  0, 21, 11, 20,  5,  1, 10,  0],
    [17,  4, 28, 81,101,101, 46, 25, 44, 12, 41,  6, 27,  8,  4, 32, 40,  1,  1,  0],
    [26, 20, 84, 42,112, 27, 14, 16,  5, 13,  3, 43,  6, 18, 12, 44,  5,  0,  0,  5]
]

graph = Graph(grid, 0.5) # provide the threshold at initialisation
path = graph.a_star((7, 4), (14, 18))
print(path)