判断两个图是否相同的算法



给定一个图,其根节点由一个Node对象定义:

class Node:
def __init__(self, val = 0, neighbors = None):
self.val = val
self.neighbors = neighbors if neighbors is not None else []

def __eq__(self, other) -> bool:

if isinstance(other, self.__class__):
return self.__repr__() == other.__repr__()

def __repr__(self):
return f"Node(val: {self.val}, neighbors: {self.neighbors})"

def __str__(self):
return self.__repr__()

Graph类定义如下,它使用上面的Node类从邻接表构造自己

class Graph:

def __init__(self, adj_list=[]):

self.root = adj_list or self.make_graph(adj_list)


def __repr__(self):
return str(self.root)

def __str__(self):
return self.__repr__()

def __eq__(self, other):

if isinstance(other, self.__class__):
return other.root == self.root

return False


def make_graph(self, adj_list) -> Node:
# Ref: https://stackoverflow.com/a/72499884/16378872
nodes = [Node(i + 1) for i in range(len(adj_list))]

for i, neighbors in enumerate(adj_list):
nodes[i].neighbors = [nodes[j-1] for j in neighbors]

return nodes[0]

例如,将邻接表[[2,4],[1,3],[2,4],[1,3]]转换为Graph,如下所示

graph = Graph([[2,4],[1,3],[2,4],[1,3]])
print(graph)
Node(val: 1, neighbors: [Node(val: 2, neighbors: [Node(val: 1, neighbors: [...]), Node(val: 3, neighbors: [Node(val: 2, neighbors: [...]), Node(val: 4, neighbors: [Node(val: 1, neighbors: [...]), Node(val: 3, neighbors: [...])])])]), Node(val: 4, neighbors: [Node(val: 1, neighbors: [...]), Node(val: 3, neighbors: [Node(val: 2, neighbors: [Node(val: 1, neighbors: [...]), Node(val: 3, neighbors: [...])]), Node(val: 4, neighbors: [...])])])])

现在如果我有两个图:

graph1 = Graph([[2,4],[1,3],[2,4],[1,3]])
graph2 = Graph([[2,4],[1,3],[2,4],[1,3]])
print(graph1 == graph2)
True

我可以通过比较graph1graph2这两个图对象Node.__repr__()的返回值来检查它们的相等性,这实际上是通过Graph__eq__()特殊方法来完成的,即比较两个图的根节点的相等性,因此使用Node__repr__()特殊方法,如上文所述。

__repr__方法将输出中的深度嵌套邻居截断为[...],但可能存在一些深层节点的Node.val值不相等,从而使该方法的比较结果不可靠。

我关心的是,是否有更好更可靠的方法来做这个相等性测试,而不仅仅是比较两个图的根节点的__repr__()?

您可以实现深度优先遍历并按值和度比较节点。将节点标记为已访问,以避免第二次遍历它们:

def __eq__(self, other) -> bool:
visited = set()

def dfs(a, b):
if a.val != b.val or len(a.neighbors) != len(b.neighbors):
return False
if a.val in visited:
return True
visited.add(a.val)
return all(dfs(*pair) for pair in zip(a.neighbors, b.neighbors))

return isinstance(other, self.__class__) and dfs(self, other)

这段代码假设一个节点的值唯一地标识同一个图中的节点

这也假定图是连通的,否则与根节点断开的组件将不会被比较。

相关内容

  • 没有找到相关文章

最新更新