我现在正在努力寻找具有最大总和的根到叶路径。我的方法是:
def max_sum(root):
_max = 0
find_max(root, _max, 0)
return _max
def find_max(node, max_sum, current_sum):
if not node:
return 0
current_sum += node.value
if not node.left and not node.right:
print(current_sum, max_sum, current_sum > max_sum)
max_sum = max(max_sum, current_sum)
if node.left:
find_max(node.left, max_sum, current_sum)
if node.right:
find_max(node.right, max_sum, current_sum)
current_sum -= node.value
class TreeNode():
def __init__(self, _value):
self.value = _value
self.left, self.right, self.next = None, None, None
def main():
root = TreeNode(1)
root.left = TreeNode(7)
root.right = TreeNode(9)
root.left.left = TreeNode(4)
root.left.right = TreeNode(5)
root.right.left = TreeNode(2)
root.right.right = TreeNode(7)
print(max_sum(root))
root = TreeNode(12)
root.left = TreeNode(7)
root.right = TreeNode(1)
root.left.left = TreeNode(4)
root.right.left = TreeNode(10)
root.right.right = TreeNode(5)
print(max_sum(root))
main()
带输出:
12 0 True
13 0 True
12 0 True
17 0 True
0
23 0 True
23 0 True
18 0 True
0
Process finished with exit code 0
预期输出为 17 和 23。
我想确认为什么我的方法无法比较max_sum
和current_sum
?即使它在比较中返回了 true,但不会更新max_sum
。感谢您的帮助。
错误修复
以下是我们可以修复您的find_sum
函数的方法 -
def find_max(node, current_sum = 0):
# empty tree
if not node:
return current_sum
# branch
elif node.left or node.right:
next_sum = current_sum + node.value
left = find_max(node.left, next_sum)
right = find_max(node.right, next_sum)
return max(left, right)
# leaf
else:
return current_sum + node.value
t1 = TreeNode
( 1
, TreeNode(7, TreeNode(4), TreeNode(5))
, TreeNode(9, TreeNode(2), TreeNode(7))
)
t2 = TreeNode
( 12
, TreeNode(7, TreeNode(4), None)
, TreeNode(1, TreeNode(10), TreeNode(5))
)
print(find_max(t1))
print(find_max(t2))
17
23
<小时 />看到过程
我们可以通过跟踪其中一个示例来可视化计算过程,find_max(t2)
-
12
/
7 1
/ /
4 None 10 5
find_max(12,0)
/
7 1
/ /
4 None 10 5
find_max(12,0)
/
max(find_max(7,12), find_max(1,12))
/ /
4 None 10 5
find_max(12,0)
/
max(find_max(7,12), find_max(1,12))
/ /
max(find_max(4,19), find_max(None,19)) max(find_max(10,13), find_max(5,13))
find_max(12,0)
/
max(find_max(7,12), find_max(1,12))
/ /
max(23, 19) max(23, 18)
find_max(12,0)
/
max(find_max(7,12), find_max(1,12))
| |
23 23
find_max(12,0)
/
max(23, 23)
find_max(12,0)
|
23
23
<小时 />细化
但是我认为我们可以改进。就像我们在你之前的问题中所做的那样,我们可以再次使用数学归纳——
- 如果输入树
t
为空,则返回空结果 - (感应)
t
不为空。 如果存在子问题t.left
或t.right
分支,请将t.value
添加到累积结果r
中,并在每个分支上重复 - (归纳)
t
不为空,t.left
和t.right
均为空;已到达叶节点;将t.value
添加到累积结果r
并得到总和
def sum_branch (t, r = 0):
if not t:
return # (1)
elif t.left or t.right:
yield from sum_branch(t.left, r + t.value) # (2)
yield from sum_branch(t.right, r + t.value)
else:
yield r + t.value # (3)
t1 = TreeNode
( 1
, TreeNode(7, TreeNode(4), TreeNode(5))
, TreeNode(9, TreeNode(2), TreeNode(7))
)
t2 = TreeNode
( 12
, TreeNode(7, TreeNode(4), None)
, TreeNode(1, TreeNode(10), TreeNode(5))
)
print(max(sum_branch(t1)))
print(max(sum_branch(t2)))
17
23
<小时 />泛型
也许写这个问题的一个更有趣的方法是先写一个泛型paths
函数——
def paths (t, p = []):
if not t:
return # (1)
elif t.left or t.right:
yield from paths(t.left, [*p, t.value]) # (2)
yield from paths(t.right, [*p, t.value])
else:
yield [*p, t.value] # (3)
然后我们可以将最大和问题作为泛型函数max
、sum
和paths
的组合来解决 -
print(max(sum(x) for x in paths(t1)))
print(max(sum(x) for x in paths(t2)))
17
23