给定一个图,其根节点由一个Node
对象定义:
class Node:
def __init__(self, val = 0, neighbors = None):
self.val = val
self.neighbors = neighbors if neighbors is not None else []
def __eq__(self, other) -> bool:
if isinstance(other, self.__class__):
return self.__repr__() == other.__repr__()
def __repr__(self):
return f"Node(val: {self.val}, neighbors: {self.neighbors})"
def __str__(self):
return self.__repr__()
Graph类定义如下,它使用上面的Node
类从邻接表构造自己
class Graph:
def __init__(self, adj_list=[]):
self.root = adj_list or self.make_graph(adj_list)
def __repr__(self):
return str(self.root)
def __str__(self):
return self.__repr__()
def __eq__(self, other):
if isinstance(other, self.__class__):
return other.root == self.root
return False
def make_graph(self, adj_list) -> Node:
# Ref: https://stackoverflow.com/a/72499884/16378872
nodes = [Node(i + 1) for i in range(len(adj_list))]
for i, neighbors in enumerate(adj_list):
nodes[i].neighbors = [nodes[j-1] for j in neighbors]
return nodes[0]
例如,将邻接表[[2,4],[1,3],[2,4],[1,3]]
转换为Graph
,如下所示
graph = Graph([[2,4],[1,3],[2,4],[1,3]])
print(graph)
Node(val: 1, neighbors: [Node(val: 2, neighbors: [Node(val: 1, neighbors: [...]), Node(val: 3, neighbors: [Node(val: 2, neighbors: [...]), Node(val: 4, neighbors: [Node(val: 1, neighbors: [...]), Node(val: 3, neighbors: [...])])])]), Node(val: 4, neighbors: [Node(val: 1, neighbors: [...]), Node(val: 3, neighbors: [Node(val: 2, neighbors: [Node(val: 1, neighbors: [...]), Node(val: 3, neighbors: [...])]), Node(val: 4, neighbors: [...])])])])
现在如果我有两个图:
graph1 = Graph([[2,4],[1,3],[2,4],[1,3]])
graph2 = Graph([[2,4],[1,3],[2,4],[1,3]])
print(graph1 == graph2)
True
我可以通过比较graph1
和graph2
这两个图对象Node.__repr__()
的返回值来检查它们的相等性,这实际上是通过Graph
的__eq__()
特殊方法来完成的,即比较两个图的根节点的相等性,因此使用Node
的__repr__()
特殊方法,如上文所述。
__repr__
方法将输出中的深度嵌套邻居截断为[...]
,但可能存在一些深层节点的Node.val
值不相等,从而使该方法的比较结果不可靠。
我关心的是,是否有更好更可靠的方法来做这个相等性测试,而不仅仅是比较两个图的根节点的__repr__()
?
您可以实现深度优先遍历并按值和度比较节点。将节点标记为已访问,以避免第二次遍历它们:
def __eq__(self, other) -> bool:
visited = set()
def dfs(a, b):
if a.val != b.val or len(a.neighbors) != len(b.neighbors):
return False
if a.val in visited:
return True
visited.add(a.val)
return all(dfs(*pair) for pair in zip(a.neighbors, b.neighbors))
return isinstance(other, self.__class__) and dfs(self, other)
这段代码假设一个节点的值唯一地标识同一个图中的节点。
这也假定图是连通的,否则与根节点断开的组件将不会被比较。