如何实现numba jit优先队列?



我无法实现一个经过numba即时编译(jitted)的优先级队列。

下面的纯Python实现大量借鉴了Python官方文档,我对这个类相当满意。

import itertools
import numba as nb
from numba.experimental import jitclass
from typing import List, Tuple, Dict
from heapq import heappush, heappop

class PurePythonPriorityQueue:
    """Priority queue with updatable priorities, built on ``heapq``.

    Entries are ``[priority, count, item]`` lists kept in a heap. Updating
    an item never touches the heap directly: the old entry is flagged as
    removed and is skipped when it eventually surfaces in :meth:`pop`.
    """

    def __init__(self):
        self.pq = []  # list of entries arranged in a heap
        self.entry_finder = {}  # mapping of items to their live entries
        self.REMOVED = -1  # placeholder for a removed item
        self.counter = itertools.count()  # unique sequence count

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Add a new item or update the priority of an existing item."""
        if item in self.entry_finder:
            self.remove_item(item)
        count = next(self.counter)
        # ``count`` is unique, so comparing two entries is always decided
        # before the (possibly replaced) ``item`` slot is reached.
        entry = [priority, count, item]
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Tuple[int, int]):
        """Mark an existing item as REMOVED.  Raise KeyError if not found."""
        entry = self.entry_finder.pop(item)
        # The entry stays in the heap; only its item slot is overwritten.
        entry[-1] = self.REMOVED

    def pop(self):
        """Remove and return the lowest priority item.

        Entries flagged as removed are discarded along the way.
        Raise KeyError if the queue holds no live item.
        """
        while self.pq:
            priority, count, item = heappop(self.pq)
            # Identity check against the very object stored by remove_item.
            if item is not self.REMOVED:
                del self.entry_finder[item]
                return item
        raise KeyError("pop from an empty priority queue")

现在我想从执行大量数值计算的numba jit函数中调用它,所以我试着把它改写成一个numba jitclass。由于在纯Python实现中条目是异构列表,我认为还需要把条目本身也实现为jitclass。然而,我得到了一个 Failed in nopython mode pipeline (step: nopython frontend) 错误(完整的回溯信息见下文)。

这是我的尝试:

@jitclass
class Item:
    """Pair of integers ``(i, j)`` used as the queue item."""

    # Field types are declared through annotations for numba.
    i: int
    j: int

    def __init__(self, i, j):
        self.i = i
        self.j = j

@jitclass
class Entry:
    """Heap entry: a priority, an insertion counter, the item, a removed flag."""

    # Field types are declared through annotations for numba.
    priority: float
    count: int
    item: Item
    removed: bool

    def __init__(self, p: float, c: int, i: Item):
        self.priority = p
        self.count = c
        self.item = i
        # New entries are live; the flag is set later to drop them lazily.
        self.removed = False

@jitclass
class PriorityQueue:
    """Attempted jitclass port of the pure-Python updatable priority queue.

    NOTE: according to the traceback reported for this code, constructing
    the class fails with a TypingError because numba finds no ``eq``
    implementation for ``Item`` instances when building the typed dict.
    """

    # Field types are declared through annotations for numba.
    pq: List[Entry]
    entry_finder: Dict[Item, Entry]
    counter: int

    def __init__(self):
        # Prototype instances are passed only to fix the container types.
        self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
        self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
        self.counter = 0

    def put(self, item: Item, priority: float = 0.0):
        """Add a new item or update the priority of an existing item"""
        if item in self.entry_finder:
            self.remove_item(item)
        self.counter += 1
        entry = Entry(priority, self.counter, item)
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Item):
        """Mark an existing item as REMOVED.  Raise KeyError if not found."""
        entry = self.entry_finder.pop(item)
        entry.removed = True

    def pop(self):
        """Remove and return the lowest priority item. Raise KeyError if empty."""
        while self.pq:
            # Pop exactly once per iteration: entries are Entry instances,
            # so they are read through attributes rather than unpacked.
            entry = heappop(self.pq)
            if not entry.removed:
                del self.entry_finder[entry.item]
                return entry.item
        raise KeyError("pop from an empty priority queue")

if __name__ == "__main__":
    # The pure-Python queue behaves as expected.
    queue1 = PurePythonPriorityQueue()
    queue1.put((4, 5), 5.4)
    queue1.put((5, 6), 1.0)
    print(queue1.pop())  # Yay this works!
    # Constructing the jitclass version is where the reported TypingError
    # is raised.
    queue2 = PriorityQueue()  # Nope
    queue2.put(Item(4, 5), 5.4)
    queue2.put(Item(5, 6), 1.0)
    print(queue2.pop())

这种类型的数据结构可以用numba实现吗?我当前的实现有什么问题?

完整的回溯信息:

(5, 6)
Traceback (most recent call last):
File "/home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py", line 106, in <module>
queue2 = PriorityQueue()  # Nope
File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/experimental/jitclass/base.py", line 122, in __call__
return cls._ctor(*bind.args[1:], **bind.kwargs)
File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
error_rewrite(e, 'typing')
File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 361, in error_rewrite
raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function typeddict_empty at 0x7fead8c3f8b0>) found for signature:
>>> typeddict_empty(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'typeddict_empty': File: numba/typed/typeddict.py: Line 213.
With argument(s): '(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function new_dict at 0x7fead9002a60>) found for signature:
>>> new_dict(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'impl_new_dict': File: numba/typed/dictobject.py: Line 639.
With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython mode backend)
No implementation of function Function(<built-in function eq>) found for signature:
>>> eq(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)
There are 30 candidate implementations:
- Of which 28 did not match due to:
Overload of function 'eq': File: <numerous>: Line N/A.
With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
No match.
- Of which 2 did not match due to:
Operator Overload in function 'eq': File: unknown: Line unknown.
With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
No match for registered cases:
* (bool, bool) -> bool
* (int8, int8) -> bool
* (int16, int16) -> bool
* (int32, int32) -> bool
* (int64, int64) -> bool
* (uint8, uint8) -> bool
* (uint16, uint16) -> bool
* (uint32, uint32) -> bool
* (uint64, uint64) -> bool
* (float32, float32) -> bool
* (float64, float64) -> bool
* (complex64, complex64) -> bool
* (complex128, complex128) -> bool
During: lowering "$20call_function.8 = call $12load_global.4(dp, $16load_deref.6, $18load_deref.7, func=$12load_global.4, args=[Var(dp, dictobject.py:653), Var($16load_deref.6, dictobject.py:654), Var($18load_deref.7, dictobject.py:654)], kws=(), vararg=None)" at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/dictobject.py (654)
raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/types/functions.py:229
During: resolving callee type: Function(<function new_dict at 0x7fead9002a60>)
During: typing of call at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py (219)

File "../../../../../.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py", line 219:
def impl(cls, key_type, value_type):
return dictobject.new_dict(key_type, value_type)
^
raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/typeinfer.py:1071
- Resolution failure for non-literal arguments:
None
During: resolving callee type: BoundFunction((<class 'numba.core.types.abstract.TypeRef'>, 'empty') for typeref[<class 'numba.core.types.containers.DictType'>])
During: typing of call at /home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py (72)

File "priorityqueue.py", line 72:
def __init__(self):
<source elided>
self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
^
During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)
During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)

File "<string>", line 3:
<source missing, REPL/exec in use?>

Process finished with exit code 1

由于numba中存在的几个问题,目前这还无法实现,但如果我理解正确的话,这些问题应该会在下一个版本(0.55)中修复。作为临时的解决方案,我通过自行编译llvmlite 0.38.0dev0和numba的主分支让它工作了。我不使用conda,但显然使用conda更容易获得llvmlite和numba的预发布版。

这是我的实现:

from heapq import heappush, heappop
from typing import List, Tuple, Dict, Any
import numba as nb
import numpy as np
from numba.experimental import jitclass

class UpdatablePriorityQueueEntry:
    """Heap entry pairing an item with its priority.

    Entries are ordered by ``priority`` alone, which is what ``heapq``
    relies on through ``<``.
    """

    def __init__(self, p: float, i: Any):
        self.priority = p
        self.item = i

    def __lt__(self, other: "UpdatablePriorityQueueEntry"):
        """Return True when this entry has a strictly lower priority."""
        return self.priority < other.priority

class UpdatablePriorityQueue:
    """Priority queue whose item priorities can be updated by re-inserting.

    Every ``put`` pushes a fresh heap entry and records the item's current
    priority in ``entries_priority``. Older entries for the same item stay
    in the heap and are skipped in ``pop`` because their priority no longer
    matches the recorded one.
    """

    def __init__(self):
        self.pq = []  # heap of entries
        self.entries_priority = {}  # item -> most recently assigned priority

    def put(self, item: Any, priority: float = 0.0):
        """Insert ``item``, or give it a new priority if already present."""
        entry = UpdatablePriorityQueueEntry(priority, item)
        self.entries_priority[item] = priority
        heappush(self.pq, entry)

    def pop(self) -> Any:
        """Remove and return the item with the lowest current priority.

        Raise KeyError when no up-to-date entry is left in the heap.
        """
        while self.pq:
            entry = heappop(self.pq)
            # Only an entry carrying the recorded priority is up to date.
            if entry.priority == self.entries_priority[entry.item]:
                # Recording infinity makes remaining stale entries for this
                # item (which carry finite priorities) fail the check above.
                self.entries_priority[entry.item] = np.inf
                return entry.item
        raise KeyError("pop from an empty priority queue")

    def clear(self):
        """Drop all heap entries and all recorded priorities."""
        self.pq.clear()
        self.entries_priority.clear()

@jitclass
class PriorityQueueEntry(UpdatablePriorityQueueEntry):
    """Jitclass version of the heap entry for ``(int, int)`` tuple items."""

    # Field types are declared through annotations for numba.
    priority: float
    item: Tuple[int, int]

    def __init__(self, p: float, i: Tuple[int, int]):
        self.priority = p
        self.item = i

@jitclass
class UpdatablePriorityQueue(UpdatablePriorityQueue):
    """Jitclass version of the updatable queue for ``(int, int)`` items.

    The base class is the plain-Python ``UpdatablePriorityQueue`` bound to
    this name at the time the class statement runs; afterwards the name
    refers to this jitclass. ``__init__`` and ``put`` are redefined here
    with numba typed containers and the jitclass entry type.
    """

    # Field types are declared through annotations for numba.
    pq: List[PriorityQueueEntry]
    entries_priority: Dict[Tuple[int, int], float]

    def __init__(self):
        # Prototype values are passed only to fix the container types.
        self.pq = nb.typed.List.empty_list(PriorityQueueEntry(0.0, (0, 0)))
        self.entries_priority = nb.typed.Dict.empty((0, 0), 0.0)

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Insert ``item``, or give it a new priority if already present."""
        entry = PriorityQueueEntry(priority, item)
        self.entries_priority[item] = priority
        heappush(self.pq, entry)

我有一个与自定义类Entry相关的类似问题. 基本上Numba无法使用__lt__(self, other)来比较条目,并给了我一个No implementation of function Function(< built-in function lt >)错误。

所以我想到了下面这些。它在Numba 0.55.1, Python 3.8和Ubuntu 18.04上工作。窍门是避免使用任何自定义类对象作为优先级队列项的一部分,以避免上述错误。

from typing import List, Dict, Tuple 
from heapq import heappush, heappop
import numba as nb
from numba.experimental import jitclass
# Prototype queue entry, used only to derive the numba type of an entry.
# Field order: (priority, counter, item, removed). `removed` is a
# one-element typed list so the flag can be flipped in place even though
# the entry itself is a tuple.
entry_def = (0.0, 0, (0,0), nb.typed.List([False]))
entry_type = nb.typeof(entry_def)
@jitclass
class PriorityQueue:
    """Updatable priority queue whose entries are plain tuples.

    Each entry is ``(priority, counter, item, removed)`` where ``removed``
    is a one-element typed list acting as a mutable "deleted" flag.
    """

    # The following helps numba infer type of variable
    pq: List[entry_type]
    entry_finder: Dict[Tuple[int, int], entry_type]
    counter: int
    entry: entry_type

    def __init__(self):
        # Must declare types here see https://numba.pydata.org/numba-doc/dev/reference/pysupported.html
        self.pq = nb.typed.List.empty_list((0.0, 0, (0,0), nb.typed.List([False])))
        self.entry_finder = nb.typed.Dict.empty( (0, 0), (0.0, 0, (0,0), nb.typed.List([False])))
        self.counter = 0

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Add a new item or update the priority of an existing item"""
        if item in self.entry_finder:
            # Mark duplicate item for deletion
            self.remove_item(item)

        self.counter += 1
        entry = (priority, self.counter, item, nb.typed.List([False]))
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Tuple[int, int]):
        """Mark an existing item as REMOVED via True.  Raise KeyError if not found."""
        self.entry = self.entry_finder.pop(item)
        # Flip the flag held in the entry's one-element list.
        self.entry[3][0] = True

    def pop(self):
        """Remove and return the lowest priority item as ``(priority, item)``.

        Entries flagged as removed are discarded. Raise KeyError if empty.
        """
        while self.pq:
            priority, count, item, removed = heappop(self.pq)
            if not removed[0]:
                del self.entry_finder[item]
                return priority, item
        raise KeyError("pop from an empty priority queue")

首先定义一个名为entry_def的全局变量,它描述优先级队列pq中条目的结构。原先的"removed"哨兵(sentinel)现在由numba.typed.List([False])代替,这样在某个项的优先级被更新时,可以把旧条目标记为待删除(延迟删除)。比较烦人的是,在__init__中必须把pq和entry_finder的条目定义重新完整写一遍;我无法在那里重用entry_def变量。

我可以确认PriorityQueue的工作如下:

q = PriorityQueue()
q.put((1,1), 5.0)
q.put((1,1), 4.0)
q.put((1,1), 3.0)
q.put((1,1), 6.0)
print(q.pq)
>>  [(3.0, 3, (1, 1), ListType[bool]([True])), (5.0, 1, (1, 1), ListType[bool]([True])), (4.0, 2, (1, 1), ListType[bool]([True])), (6.0, 4, (1, 1), ListType[bool]([False]))]
print(q.pop())
>> (6.0, (1, 1))
print(len(q.entry_finder))
>> 0

希望有人会觉得这有用,或者可以提供更好的替代方案。

最新更新