如何实现 numba jitted 优先级队列?

How can I implement a numba jitted priority queue?

我无法实现 numba jitted 优先级队列。

下面这个 class 大量借鉴了 Python 官方文档中的示例,我对它相当满意。

import itertools

import numba as nb
from numba.experimental import jitclass
from typing import List, Tuple, Dict
from heapq import heappush, heappop


class PurePythonPriorityQueue:
    """Min-priority queue with updatable priorities, built on ``heapq``.

    Each heap entry is a ``[priority, count, item]`` list. Entries are never
    deleted from the middle of the heap: a superseded entry has its item slot
    overwritten with ``REMOVED`` and is thrown away once it reaches the top.
    """

    def __init__(self):
        self.pq = []  # heap of [priority, count, item] entries
        self.entry_finder = {}  # item -> the live heap entry for that item
        self.REMOVED = -1  # marker stored in the item slot of a dead entry
        self.counter = itertools.count()  # second sort key, unique per entry

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Queue ``item`` with ``priority``, replacing any entry it already has."""
        if item in self.entry_finder:
            self.remove_item(item)
        fresh = [priority, next(self.counter), item]
        self.entry_finder[item] = fresh
        heappush(self.pq, fresh)

    def remove_item(self, item: Tuple[int, int]):
        """Flag the entry of ``item`` as removed; raise KeyError if it is not queued."""
        stale = self.entry_finder.pop(item)
        stale[2] = self.REMOVED

    def pop(self):
        """Return the live item with the smallest priority; raise KeyError if none is left."""
        while self.pq:
            *_, candidate = heappop(self.pq)
            if candidate is self.REMOVED:
                # Dead entry left behind by remove_item(); skip it.
                continue
            del self.entry_finder[candidate]
            return candidate
        raise KeyError("pop from an empty priority queue")

现在我想从一个 numba jitted 函数中调用它来做繁重的数值工作,所以我试着把它变成一个 numba jitclass。由于在 vanilla python 实现中条目是异构列表,我想我也应该实现其他 jitclasses。但是,我得到了 Failed in nopython mode pipeline (step: nopython frontend)(下面的完整跟踪)。

这是我的尝试:

@jitclass
class Item:
    """Queue item made of two integer fields, ``i`` and ``j``."""

    i: int
    j: int

    def __init__(self, i, j):
        """Store the two integer components of the item."""
        self.i = i
        self.j = j


@jitclass
class Entry:
    """Heap entry bundling a priority, an insertion count, an item and a removal flag."""

    priority: float
    count: int
    item: Item
    removed: bool  # set to True by the queue instead of deleting the entry

    def __init__(self, p: float, c: int, i: Item):
        """Create an entry for item ``i`` with priority ``p`` and count ``c``, not yet removed."""
        self.priority = p
        self.count = c
        self.item = i
        self.removed = False


@jitclass
class PriorityQueue:
    """jitclass port of the pure-Python updatable priority queue.

    ``pq`` is a heap of ``Entry`` objects, ``entry_finder`` maps each queued
    ``Item`` to its live entry, and ``counter`` numbers entries in insertion
    order. Superseded entries are flagged through ``Entry.removed`` and
    skipped when popped.
    """

    pq: List[Entry]
    entry_finder: Dict[Item, Entry]
    counter: int

    def __init__(self):
        """Create the empty typed containers and reset the insertion counter."""
        self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
        # NOTE: this is the call numba was observed to reject with a
        # TypingError, because it finds no `eq` implementation for the Item
        # jitclass used as the dictionary key.
        self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
        self.counter = 0

    def put(self, item: Item, priority: float = 0.0):
        """Add a new item or update the priority of an existing item"""
        if item in self.entry_finder:
            self.remove_item(item)
        self.counter += 1
        entry = Entry(priority, self.counter, item)
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Item):
        """Mark an existing item as REMOVED.  Raise KeyError if not found."""
        entry = self.entry_finder.pop(item)
        entry.removed = True

    def pop(self):
        """Remove and return the lowest priority item. Raise KeyError if empty."""
        while self.pq:
            # Pop exactly once per iteration. The heap holds Entry objects,
            # so the result cannot be unpacked like the pure-Python list
            # entries, and a second heappop would silently drop an entry.
            entry = heappop(self.pq)
            if not entry.removed:
                del self.entry_finder[entry.item]
                return entry.item
        raise KeyError("pop from an empty priority queue")


if __name__ == "__main__":
    # Pure-Python queue: (5, 6) has the smaller priority (1.0) and is popped first.
    queue1 = PurePythonPriorityQueue()
    queue1.put((4, 5), 5.4)
    queue1.put((5, 6), 1.0)
    print(queue1.pop())  # Yay this works!

    # jitclass queue: the TypingError is raised by the constructor call itself,
    # so the statements after it are never reached.
    queue2 = PriorityQueue()  # Nope
    queue2.put(Item(4, 5), 5.4)
    queue2.put(Item(5, 6), 1.0)
    print(queue2.pop())

这种类型的数据结构可以用 numba 实现吗?我当前的实现有什么问题?

完整跟踪:

(5, 6)
Traceback (most recent call last):
  File "/home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py", line 106, in <module>
    queue2 = PriorityQueue()  # Nope
  File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/experimental/jitclass/base.py", line 122, in __call__
    return cls._ctor(*bind.args[1:], **bind.kwargs)
  File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 361, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function typeddict_empty at 0x7fead8c3f8b0>) found for signature:

 >>> typeddict_empty(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'typeddict_empty': File: numba/typed/typeddict.py: Line 213.
    With argument(s): '(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   No implementation of function Function(<function new_dict at 0x7fead9002a60>) found for signature:

    >>> new_dict(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)

   There are 2 candidate implementations:
         - Of which 2 did not match due to:
         Overload in function 'impl_new_dict': File: numba/typed/dictobject.py: Line 639.
           With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
          Rejected as the implementation raised a specific error:
            TypingError: Failed in nopython mode pipeline (step: nopython mode backend)
          No implementation of function Function(<built-in function eq>) found for signature:

           >>> eq(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)

          There are 30 candidate implementations:
                - Of which 28 did not match due to:
                Overload of function 'eq': File: <numerous>: Line N/A.
                  With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
                 No match.
                - Of which 2 did not match due to:
                Operator Overload in function 'eq': File: unknown: Line unknown.
                  With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
                 No match for registered cases:
                  * (bool, bool) -> bool
                  * (int8, int8) -> bool
                  * (int16, int16) -> bool
                  * (int32, int32) -> bool
                  * (int64, int64) -> bool
                  * (uint8, uint8) -> bool
                  * (uint16, uint16) -> bool
                  * (uint32, uint32) -> bool
                  * (uint64, uint64) -> bool
                  * (float32, float32) -> bool
                  * (float64, float64) -> bool
                  * (complex64, complex64) -> bool
                  * (complex128, complex128) -> bool

          During: lowering "call_function.8 = call load_global.4(dp, load_deref.6, load_deref.7, func=load_global.4, args=[Var(dp, dictobject.py:653), Var(load_deref.6, dictobject.py:654), Var(load_deref.7, dictobject.py:654)], kws=(), vararg=None)" at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/dictobject.py (654)
     raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/types/functions.py:229

   During: resolving callee type: Function(<function new_dict at 0x7fead9002a60>)
   During: typing of call at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py (219)


   File "../../../../../.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py", line 219:
       def impl(cls, key_type, value_type):
           return dictobject.new_dict(key_type, value_type)
           ^

  raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/typeinfer.py:1071

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.abstract.TypeRef'>, 'empty') for typeref[<class 'numba.core.types.containers.DictType'>])
During: typing of call at /home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py (72)


File "priorityqueue.py", line 72:
    def __init__(self):
        <source elided>
        self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
        self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
        ^

During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)

During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>


Process finished with exit code 1

由于 numba 中的几个问题,目前还做不到这一点;但如果我理解正确的话,这些问题应该会在下一个版本 (0.55) 中修复。作为目前的解决方法,我通过自行编译 llvmlite 0.38.0dev0 和 numba 的主分支让它工作了。我不使用 conda,但据说通过 conda 获取 llvmlite 和 numba 的预发布版本更容易。

这是我的实现:

from heapq import heappush, heappop
from typing import List, Tuple, Dict, Any

import numba as nb
import numpy as np
from numba.experimental import jitclass


class UpdatablePriorityQueueEntry:
    """Heap entry pairing an item with its priority; ordered by priority alone."""

    def __init__(self, p: float, i: Any):
        """Remember priority ``p`` and item ``i``."""
        self.priority, self.item = p, i

    def __lt__(self, other: "UpdatablePriorityQueueEntry"):
        """Return True when this entry's priority is strictly below ``other``'s."""
        mine, theirs = self.priority, other.priority
        return mine < theirs


class UpdatablePriorityQueue:
    """Priority queue whose items can be re-queued with a new priority.

    ``pq`` is a heap of ``UpdatablePriorityQueueEntry`` objects and
    ``entries_priority`` records the most recently assigned priority of each
    item. A popped entry counts only if its priority still equals the recorded
    one; otherwise it is a leftover from an earlier ``put`` and is dropped.
    """

    def __init__(self):
        self.pq = []
        self.entries_priority = {}

    def put(self, item: Any, priority: float = 0.0):
        """Queue ``item`` and record ``priority`` as its current priority."""
        new_entry = UpdatablePriorityQueueEntry(priority, item)
        self.entries_priority[item] = priority
        heappush(self.pq, new_entry)

    def pop(self) -> Any:
        """Return the item of the first popped entry that is still current.

        Raises KeyError once the heap has been exhausted.
        """
        heap, current = self.pq, self.entries_priority
        while heap:
            head = heappop(heap)
            if head.priority != current[head.item]:
                # Outdated entry: the item was re-queued or already returned.
                continue
            # Overwrite the recorded priority with infinity after returning the item.
            current[head.item] = np.inf
            return head.item
        raise KeyError("pop from an empty priority queue")

    def clear(self):
        """Empty both the heap and the priority table."""
        for store in (self.pq, self.entries_priority):
            store.clear()


@jitclass
class PriorityQueueEntry(UpdatablePriorityQueueEntry):
    """jitclass version of the heap entry, with the item typed as a pair of ints.

    Subclasses UpdatablePriorityQueueEntry, which defines ``__lt__`` on the
    priority.
    """

    priority: float
    item: Tuple[int, int]

    def __init__(self, p: float, i: Tuple[int, int]):
        """Remember priority ``p`` and the ``(int, int)`` item ``i``."""
        self.priority = p
        self.item = i


@jitclass
class UpdatablePriorityQueue(UpdatablePriorityQueue):
    """jitclass version of the updatable priority queue for ``(int, int)`` items.

    The base class named in the parentheses is the pure-Python
    ``UpdatablePriorityQueue`` defined earlier in the file; this definition
    then rebinds that name to the jitted class. Only ``__init__`` and ``put``
    are redefined here, with numba typed containers in place of the built-in
    list and dict.
    """

    pq: List[PriorityQueueEntry]
    entries_priority: Dict[Tuple[int, int], float]

    def __init__(self):
        """Create the empty typed heap list and the typed priority table."""
        # The entry class defined in this file is PriorityQueueEntry; the
        # previously used name PriorityQueueEntry2d is not defined anywhere.
        self.pq = nb.typed.List.empty_list(PriorityQueueEntry(0.0, (0, 0)))
        self.entries_priority = nb.typed.Dict.empty((0, 0), 0.0)

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Queue ``item`` and record ``priority`` as its current priority."""
        entry = PriorityQueueEntry(priority, item)
        self.entries_priority[item] = priority
        heappush(self.pq, entry)

我遇到了与自定义 class Entry 相关的类似问题。基本上 Numba 无法使用 __lt__(self, other) 来比较条目,并给了我一个 No implementation of function Function(<built-in function lt>) 错误。

所以我想到了以下内容。它适用于 Python 3.8 上的 Ubuntu 18.04 上的 Numba 0.55.1。诀窍是避免使用任何自定义 class 对象作为优先级队列项目的一部分,以避免上述错误。

from typing import List, Dict, Tuple 
from heapq import heappush, heappop
import numba as nb
from numba.experimental import jitclass

# Template queue entry, used to derive the numba type of an entry.
# Fields, in order: priority, counter, item, removed
# "removed" is a one-element typed list so the flag can be changed in place
# even though the entry itself is a tuple.
entry_def = (0.0, 0, (0,0), nb.typed.List([False]))
entry_type = nb.typeof(entry_def)

@jitclass
class PriorityQueue:
    """Priority queue with updatable priorities, written as a numba jitclass.

    Entries are plain tuples ``(priority, counter, item, removed)`` instead of
    instances of a custom class. ``removed`` is a one-element typed list used
    as a mutable flag, so a superseded entry can be marked without searching
    the heap for it.
    """

    # The following helps numba infer type of variable
    pq: List[entry_type]
    entry_finder: Dict[Tuple[int, int], entry_type]
    counter: int
    entry: entry_type  # scratch slot, assigned in remove_item()

    def __init__(self):
        """Create the empty typed heap list and item lookup table; reset the counter."""
        # Must declare types here see https://numba.pydata.org/numba-doc/dev/reference/pysupported.html
        self.pq = nb.typed.List.empty_list((0.0, 0, (0,0), nb.typed.List([False])))
        self.entry_finder = nb.typed.Dict.empty( (0, 0), (0.0, 0, (0,0), nb.typed.List([False])))
        self.counter = 0

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Add a new item or update the priority of an existing item.

        When ``item`` is already queued its current entry is flagged as
        removed first; a new entry with the next counter value is then stored
        in ``entry_finder`` and pushed onto the heap.
        """
        if item in self.entry_finder:
            # Mark duplicate item for deletion
            self.remove_item(item)

        self.counter += 1
        entry = (priority, self.counter, item, nb.typed.List([False]))
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Tuple[int, int]):
        """Mark an existing item as REMOVED via True.  Raise KeyError if not found."""
        # The entry is dropped from the lookup table but stays in the heap;
        # only its flag (field 3, element 0) is switched to True.
        self.entry = self.entry_finder.pop(item)
        self.entry[3][0] = True

    def pop(self):
        """Remove the lowest priority live entry and return ``(priority, item)``.

        Entries flagged as removed are popped and skipped.
        Raise KeyError if no live entry is left.
        """
        while self.pq:
            priority, count, item, removed = heappop(self.pq)
            if not removed[0]:
                del self.entry_finder[item]
                return priority, item
        raise KeyError("pop from an empty priority queue")

首先定义一个名为 entry_def 的全局变量,它描述优先级队列 pq 中条目的结构。“已删除”标记现在改用 numba.typed.List([False]) 表示,以便在优先级更改时(延迟删除)标记要删除的条目。烦人的地方在于必须在 pq 和 entry_finder 的定义中重复写出条目的结构;我无法重用 entry_def 变量。

我可以确认 PriorityQueue 的工作方式如下:

    q = PriorityQueue()
    q.put((1,1), 5.0)
    q.put((1,1), 4.0)
    q.put((1,1), 3.0)
    q.put((1,1), 6.0)
    print(q.pq)
    >>  [(3.0, 3, (1, 1), ListType[bool]([True])), (5.0, 1, (1, 1), ListType[bool]([True])), (4.0, 2, (1, 1), ListType[bool]([True])), (6.0, 4, (1, 1), ListType[bool]([False]))]
    print(q.pop())
    >> (6.0, (1, 1))
    print(len(q.entry_finder))
    >> 0

希望有人会觉得这有用或可以提供更好的选择。