Prevent improper use of a user-defined data structure and make the DS reusable



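Goals:

  • Define a priority queue implemented with a max heap
  • Provide as much type safety as possible, together with a flexible/reusable implementation (even though Python is dynamic and type safety is only hinted at)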

I looked at several implementations on the site, and people seem to assume that int and float will be the only elements ever stored in a max PQ.

What if we want to track, say, Person objects by age, or Transactions by amount? Every implementation I have seen will fail at runtime if the user inserts an arbitrary type.
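To make that failure mode concrete, here is a small illustration using the standard library's heapq (a min heap, used only because it ships with Python); Animal is a hypothetical class with no ordering methods. The first push succeeds because nothing is compared yet, and the error only surfaces once a second item forces a comparison.

```python
import heapq


class Animal:
    """Hypothetical element type that defines no ordering methods."""

    def __init__(self, breed: str) -> None:
        self.breed = breed


def push_two_animals() -> str:
    heap: list = []
    heapq.heappush(heap, Animal("beagle"))  # succeeds: nothing to compare against yet
    try:
        heapq.heappush(heap, Animal("poodle"))  # needs Animal < Animal
    except TypeError as exc:
        return f"failed late: {exc}"
    return "no error"


print(push_two_animals())
```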

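Ideally, I would like to:

  1. Let users reuse this PQ implementation for other data types
  2. Fail fast if it is used improperly (e.g. inserting instances of a class the implementation does not know how to compare)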

To implement the PQ we need to be able to compare objects with comparison operators; let's assume we only need > and >=.

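After some research, I found that user-defined classes can implement the following methods, which I believe would allow more flexibility: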

==  __eq__
!=  __ne__
<   __lt__
<=  __le__
>   __gt__
>=  __ge__
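As a sketch of what implementing these methods might look like for the "Person by age" case, the class below defines __eq__ and __gt__ and lets functools.total_ordering derive the remaining ordering methods (__ne__ is derived from __eq__ automatically in Python 3). Person and its age field are assumptions for illustration, not code from the question.

```python
from functools import total_ordering


@total_ordering
class Person:
    """Hypothetical element type ordered by age."""

    def __init__(self, name: str, age: int) -> None:
        self.name = name
        self.age = age

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Person):
            return NotImplemented  # let Python try the reflected operation instead
        return self.age == other.age

    def __gt__(self, other: object) -> bool:
        if not isinstance(other, Person):
            return NotImplemented
        return self.age > other.age


alice = Person("Alice", 40)
bob = Person("Bob", 30)
print(alice > bob, alice >= bob, bob < alice)
```

A Transaction class ordered by amount would follow the same pattern with a different field.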

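So, can I check in the constructor that the comparison methods I need exist, and raise an exception if they don't? If I'm approaching this the wrong way, what other routes should I explore?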
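Type arguments such as Person in MaxHeapPriorityQueue[Person]() are not visible inside __init__ (typing only records them on the instance after the constructor returns), so one way to fail fast in the constructor is to pass the element class explicitly. The sketch below assumes a hypothetical item_type parameter and treats a class as comparable only if it overrides __gt__ and __ge__ rather than inheriting object's defaults, which just return NotImplemented.

```python
from typing import Generic, List, Type, TypeVar

T = TypeVar("T")

REQUIRED_METHODS = ("__gt__", "__ge__")  # the operations the heap relies on


class MaxHeapPriorityQueue(Generic[T]):
    def __init__(self, item_type: Type[T]) -> None:  # item_type is a hypothetical parameter
        # Fail fast: a class that merely inherits object's comparison methods
        # cannot actually be ordered.
        missing = [
            name for name in REQUIRED_METHODS
            if getattr(item_type, name) is getattr(object, name)
        ]
        if missing:
            raise TypeError(f"{item_type.__name__} does not define {', '.join(missing)}")
        self._item_type = item_type
        self._heap: List[T] = []


class Animal:
    def __init__(self, breed: str) -> None:
        self.breed = breed


ints = MaxHeapPriorityQueue(int)  # int defines __gt__ and __ge__, so this is accepted
try:
    MaxHeapPriorityQueue(Animal)
except TypeError as exc:
    print(exc)
```

Keeping item_type around also makes it possible to check isinstance(value, self._item_type) in insert, which catches values of the wrong type that a static checker missed.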


Barebones code:


from typing import TypeVar, Generic, List

# Define which types the generic class accepts; ideally only classes that conform to an
# interface (i.e. that define the methods needed to compare a variety of classes)
T = TypeVar("T", int, float)

class MaxHeapPriorityQueue(Generic[T]):
    def __init__(self) -> None:
        # Check whether __gt__, __ge__, etc. are defined on type T; if all or some of
        # the ones we need are missing, raise an exception
        self._heap: List[T] = []
        self._insert_pointer: int = 0

    def insert(self, value: T) -> None:
        raise NotImplementedError  # TODO implement

    def delete_max(self) -> T:
        raise NotImplementedError  # TODO implement

    def __trickle_up(self, node_index: int) -> None:
        parent_index = self.__calculate_parent_node_index(node_index)
        # Item-to-item comparison, which may fail or lead to logic bugs if the user
        # stores non-numeric values in the heap
        while node_index > 1 and self._heap[node_index] > self._heap[parent_index]:
            self.__exchange(node_index, parent_index)
            node_index = parent_index
            parent_index = self.__calculate_parent_node_index(node_index)

    @staticmethod
    def __calculate_parent_node_index(child_node_index: int) -> int:
        return child_node_index // 2

    def __exchange(self, node_index_1: int, node_index_2: int) -> None:
        raise NotImplementedError  # TODO implement

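EDIT: Using a Protocol, the mypy check seems to work, but the runtime check with the typing module, if not isinstance(T, SupportsComparison): raise TypeError('can not instantiate with that type'), never raises an exception: execution never enters the if block.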
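A likely explanation: isinstance(T, ...) tests the TypeVar object T itself, not the type the queue was parameterized with, and for a runtime_checkable protocol isinstance only verifies that the named methods exist. Every Python object inherits the rich comparison methods from object, so such a protocol accepts practically anything. A quick way to see this (a two-method protocol is used to keep it short; Animal is a throwaway class):

```python
from typing import Protocol, TypeVar, runtime_checkable


@runtime_checkable
class SupportsComparison(Protocol):
    def __ge__(self, other) -> bool: ...
    def __gt__(self, other) -> bool: ...


T = TypeVar("T", bound=SupportsComparison)


class Animal:  # defines no comparison methods of its own
    def __init__(self, breed: str) -> None:
        self.breed = breed


# Both are True: __ge__ and __gt__ are inherited from object, so the structural
# runtime check cannot tell comparable types from non-comparable ones.
print(isinstance(T, SupportsComparison))
print(isinstance(Animal("beagle"), SupportsComparison))

# ...yet actually comparing two Animals still fails.
try:
    Animal("beagle") > Animal("poodle")
    print("compared")
except TypeError:
    print("TypeError")
```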

Generic DS:

from typing import TypeVar, Generic, List, Protocol, runtime_checkable

@runtime_checkable
class SupportsComparison(Protocol):
    def __lt__(self, other) -> bool: ...
    def __le__(self, other) -> bool: ...
    def __eq__(self, other) -> bool: ...
    def __ne__(self, other) -> bool: ...
    def __ge__(self, other) -> bool: ...
    def __gt__(self, other) -> bool: ...

T = TypeVar("T", bound=SupportsComparison)

class MaxHeapPriorityQueue(Generic[T]):
    def __init__(self) -> None:
        if not isinstance(T, SupportsComparison):
            raise TypeError('can not instantiate with that type')
        self._heap: List[T] = []
        # Pointer that places new elements so that the tree is always a complete
        # binary tree. It holds the 1-based position of the most recently added element.
        self._insert_pointer: int = 0

    def insert(self, value: T) -> None:
        # Increment before inserting, because the pointer refers to the last element
        # added, not to where the next one goes. So with 1 element it is 1, and with
        # 0 elements it is 0.
        self._insert_pointer += 1
        # Node positions are 1-based (so parent = i // 2) while the list is 0-based,
        # so node i is stored at list index i - 1.
        self._heap.append(value)
        self.__trickle_up(self._insert_pointer)

    def delete_max(self) -> T:
        if self._insert_pointer == 0:
            raise IndexError("Can not remove when PQ is empty")
        return self._heap[0]  # TODO implement removal; this only peeks at the max

    def __trickle_up(self, node_index: int) -> None:
        parent_index = self.__calculate_parent_node_index(node_index)
        # Stop trickling up once we reach the root of the binary tree or the node
        # being trickled up is not greater than its parent
        while node_index > 1 and self._heap[node_index - 1] > self._heap[parent_index - 1]:
            self.__exchange(node_index, parent_index)
            node_index = parent_index
            parent_index = self.__calculate_parent_node_index(node_index)

    @staticmethod
    def __calculate_parent_node_index(child_node_index: int) -> int:
        return child_node_index // 2

    def __exchange(self, node_index_1: int, node_index_2: int) -> None:
        # Swap two nodes given their 1-based positions
        i, j = node_index_1 - 1, node_index_2 - 1
        self._heap[i], self._heap[j] = self._heap[j], self._heap[i]
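The delete_max TODO is commonly completed by moving the last leaf to the root and sinking it back down. The sketch below shows one way to do that; it is a standalone class (named MaxHeap here to keep it separate from the question's code) that stores the root at list index 0, so node i has children 2i + 1 and 2i + 2, and it relies only on the > operator.

```python
from typing import Generic, List, Protocol, TypeVar


class SupportsGreaterThan(Protocol):
    def __gt__(self, other) -> bool: ...


T = TypeVar("T", bound=SupportsGreaterThan)


class MaxHeap(Generic[T]):
    """Hypothetical standalone max heap; the root is stored at index 0."""

    def __init__(self) -> None:
        self._heap: List[T] = []

    def __len__(self) -> int:
        return len(self._heap)

    def insert(self, value: T) -> None:
        self._heap.append(value)
        self._trickle_up(len(self._heap) - 1)

    def delete_max(self) -> T:
        if not self._heap:
            raise IndexError("Can not remove when PQ is empty")
        top = self._heap[0]
        last = self._heap.pop()  # remove the last leaf
        if self._heap:           # if anything is left, move the leaf to the root and sink it
            self._heap[0] = last
            self._sink(0)
        return top

    def _trickle_up(self, i: int) -> None:
        while i > 0:
            parent = (i - 1) // 2
            if not self._heap[i] > self._heap[parent]:
                break
            self._heap[i], self._heap[parent] = self._heap[parent], self._heap[i]
            i = parent

    def _sink(self, i: int) -> None:
        n = len(self._heap)
        while True:
            largest = i
            for child in (2 * i + 1, 2 * i + 2):
                if child < n and self._heap[child] > self._heap[largest]:
                    largest = child
            if largest == i:
                return
            self._heap[i], self._heap[largest] = self._heap[largest], self._heap[i]
            i = largest


pq: MaxHeap[int] = MaxHeap()
for x in [5, 1, 9, 3, 7]:
    pq.insert(x)
print([pq.delete_max() for _ in range(len(pq))])
```

Using a 0-based root avoids the 1-based/0-based translation entirely, at the cost of slightly less tidy parent/child arithmetic.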

Instantiation:

class Person:
    def __init__(self, name, age):
        self.name = name
        self.age = age

    def __lt__(self, other) -> bool:
        return True  # TODO IMPLEMENT, THIS IS JUST A TEST

    def __le__(self, other) -> bool:
        return True  # TODO IMPLEMENT, THIS IS JUST A TEST

    def __eq__(self, other) -> bool:
        return True  # TODO IMPLEMENT, THIS IS JUST A TEST

    def __ne__(self, other) -> bool:
        return True  # TODO IMPLEMENT, THIS IS JUST A TEST

    def __gt__(self, other) -> bool:
        return True  # TODO IMPLEMENT, THIS IS JUST A TEST

    def __ge__(self, other) -> bool:
        return True  # TODO IMPLEMENT, THIS IS JUST A TEST

class Animal:
    def __init__(self, breed):
        self.breed = breed


# The classes must be defined before they are used as type arguments
if __name__ == '__main__':
    max_pq = MaxHeapPriorityQueue[Person]()


Running the checks:

if __name__ == '__main__':
    max_pq = MaxHeapPriorityQueue[Person]()  # passes the mypy check
    max_pq2 = MaxHeapPriorityQueue[Animal]()  # fails the mypy check

Define a Protocol that requires __ge__ and __gt__ to be defined:

import typing

@typing.runtime_checkable
class SupportsComparison(typing.Protocol):
    def __ge__(self, other) -> bool:
        ...

    def __gt__(self, other) -> bool:
        ...

T = typing.TypeVar("T", bound=SupportsComparison)

class MaxHeapPriorityQueue(typing.Generic[T]):
    ...
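Another route, in the spirit of sorted() and max(), is to let the caller supply a key function, so the elements themselves need no comparison methods and only the keys must be comparable. The key parameter and the KeyedMaxPQ/_Entry names below are hypothetical; the sketch wraps each item in an entry whose __lt__ is inverted so that heapq's min heap behaves as a max heap, and uses an insertion counter to break ties so two items are never compared directly.

```python
import heapq
import itertools
from typing import Any, Callable, Generic, List, TypeVar

T = TypeVar("T")


class _Entry(Generic[T]):
    """Heap entry ordered so that heapq's min heap acts as a max heap on key."""

    def __init__(self, key: Any, seq: int, item: T) -> None:
        self.key, self.seq, self.item = key, seq, item

    def __lt__(self, other: "_Entry[T]") -> bool:
        # A larger key counts as "smaller" for heapq, so it surfaces first.
        if self.key > other.key:
            return True
        if other.key > self.key:
            return False
        return self.seq < other.seq  # tie: the earlier insertion comes out first


class KeyedMaxPQ(Generic[T]):
    """Hypothetical max PQ that orders items by key(item) instead of the items."""

    def __init__(self, key: Callable[[T], Any]) -> None:
        self._key = key
        self._heap: List[_Entry[T]] = []
        self._counter = itertools.count()

    def insert(self, item: T) -> None:
        heapq.heappush(self._heap, _Entry(self._key(item), next(self._counter), item))

    def delete_max(self) -> T:
        if not self._heap:
            raise IndexError("Can not remove when PQ is empty")
        return heapq.heappop(self._heap).item


class Person:  # no comparison methods needed
    def __init__(self, name: str, age: int) -> None:
        self.name, self.age = name, age


pq: KeyedMaxPQ[Person] = KeyedMaxPQ(key=lambda p: p.age)
for name, age in [("Ann", 31), ("Bo", 45), ("Cy", 27)]:
    pq.insert(Person(name, age))
print([pq.delete_max().name for _ in range(3)])
```

Because key(item) runs inside insert, an object without the expected attribute fails immediately with AttributeError instead of on some later comparison.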
