将 PriorityQueue 与对象一起使用时出现问题 - Python

Problem Using PriorityQueue with Objects - Python

我正在尝试创建统一成本搜索算法。但是我在优先级队列中存储节点时遇到问题。

它在节点 D 之前运行良好,如提供的输出所示,我不确定为什么。任何帮助将不胜感激。

错误说它无法比较节点,但我将它们添加为元组,因此它可以使用距离进行比较

 class GraphEdge(object):
    def __init__(self, destinationNode, distance):
        self.node = destinationNode
        self.distance = distance

class GraphNode(object):
    def __init__(self, val):
        self.value = val
        self.edges = []

    def add_child(self, node, distance):
        self.edges.append(GraphEdge(node, distance))

    def remove_child(self, del_node):
        if del_node in self.edges:
            self.edges.remove(del_node)    

class Graph(object):
    def __init__(self, node_list):
        self.nodes = node_list
   
    def add_edge(self, node1, node2, distance):
        if node1 in self.nodes and node2 in self.nodes:
            node1.add_child(node2, distance)
            node2.add_child(node1, distance)

    def remove_edge(self, node1, node2):
        if node1 in self.nodes and node2 in self.nodes:
            node1.remove_child(node2)
            node2.remove_child(node1)

from queue import PriorityQueue

def build_path(root_node, goal_node):
    path = [goal_node]    
    add_parent(root_node, goal_node, path)
    return path

def add_parent(root_node, node, path):
    parent = node.parent
    path.append(parent)
    if parent == root_node:        
        return      
    else:
        add_parent(root_node, parent, path)
    

def ucs_search(root_node, goal_node):
    visited = set()                         
    queue = PriorityQueue()
    queue.put((0, root_node))
    visited_order = []
    
    while not queue.empty():
        current_node_priority, current_node  = queue.get()
        
        visited.add(current_node)
        visited_order.append(current_node.value)
        print("current_node:", current_node.value)

        if current_node == goal_node:
            print(visited_order)
            return current_node, build_path(root_node, goal_node)        
       
        for edge in current_node.edges:
            child = edge.node
            
            if child not in visited:          
                child.parent = current_node
                print("child:", child.value)
                queue.put(((current_node_priority + edge.distance), child))

node_u = GraphNode('U')
node_d = GraphNode('D')
node_a = GraphNode('A')
node_c = GraphNode('C')
node_i = GraphNode('I')
node_t = GraphNode('T')
node_y = GraphNode('Y')

graph = Graph([node_u, node_d, node_a, node_c, node_i, node_t, node_y])

graph.add_edge(node_u, node_a, 4)
graph.add_edge(node_u, node_c, 6)
graph.add_edge(node_u, node_d, 3)
graph.add_edge(node_d, node_c, 4)
graph.add_edge(node_a, node_i, 7)
graph.add_edge(node_c, node_i, 4)
graph.add_edge(node_c, node_t, 5)
graph.add_edge(node_i, node_y, 4)
graph.add_edge(node_t, node_y, 5)

goal, sequence  = ucs_search(node_a, node_y)

输出:

current_node: A
child: U
child: I
current_node: U
child: C
child: D
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-52-2d575db64232> in <module>
     19 graph.add_edge(node_t, node_y, 5)
     20 
---> 21 goal, sequence  = ucs_search(node_a, node_y)

<ipython-input-51-b26ec19983b6> in ucs_search(root_node, goal_node)
     36                 child.parent = current_node
     37                 print("child:", child.value)
---> 38                 queue.put(((current_node_priority + edge.distance), child))
     39 

~\AppData\Local\Continuum\anaconda3\lib\queue.py in put(self, item, block, timeout)
    147                             raise Full
    148                         self.not_full.wait(remaining)
--> 149             self._put(item)
    150             self.unfinished_tasks += 1
    151             self.not_empty.notify()

~\AppData\Local\Continuum\anaconda3\lib\queue.py in _put(self, item)
    231 
    232     def _put(self, item):
--> 233         heappush(self.queue, item)
    234 
    235     def _get(self):

TypeError: '<' not supported between instances of 'GraphNode' and 'GraphNode'

如果队列中的两个元组距离相同,则优先级队列需要根据对应GraphNode的优先级值进行决胜局。由于 __lt__ 函数没有为 GraphNode 定义,这将导致错误。 (__lt__ 函数定义了如何使用 < 运算符比较两个 GraphNode。)

要解决,请为 GraphNode class 定义 __lt__ 函数。这是 Python 在比较两个 GraphNodes:

时调用的函数
class GraphNode(object):
    def __init__(self, val):
        self.value = val
        self.edges = []

    def add_child(self, node, distance):
        self.edges.append(GraphEdge(node, distance))

    def remove_child(self, del_node):
        if del_node in self.edges:
            self.edges.remove(del_node)

    def __lt__(self, other):
        return self.value < other.value