使用重复的优先级将火炬张量插入 heapq 时出错

Error inserting torch tensors into a heapq using duplicated priorities

如何在这段代码中避免RuntimeError: bool value of Tensor with more than one value is ambiguous

import torch
import heapq

h = []
heapq.heappush(h, (1, torch.Tensor([[1,2]])))
heapq.heappush(h, (1, torch.Tensor([[3,4]])))

这是因为元组之间的比较是在第一个元素相等的情况下比较第二个元素

需要防止heapq在发现重复优先级时尝试比较元组的第二个元素,只需要为我的元素重新定义<运算符

import torch
import heapq

class HeapItem:
    def __init__(self, p, t):
        self.p = p
        self.t = t

    def __lt__(self, other):
        return self.p < other.p

h = []
heapq.heappush(h, HeapItem(1, torch.Tensor([[1,2]])))
heapq.heappush(h, HeapItem(1, torch.Tensor([[3,4]])))