问题是在每个父子对被单位距离分隔的情况下,找出BinarySearchTree中每两个节点之间的距离之和。每次插入后都要计算。
,
->first node is inserted..
(root)
total sum=0;
->left and right node are inserted
(root)
/
(left) (right)
total sum = distance(root,left)+distance(root,right)+distance(left,right);
= 1 + 1 + 2
= 4
and so on.....
我想到的解决方案:
蛮力。步骤:
- 执行DFS并跟踪所有节点:
O(n)
. - 选取每两个节点,用最低共同祖先法计算它们之间的
O(nC2)_times_O(log(n))=O(n2log(n))
距离并相加。
总体复杂度:
-O(n2log(n))
。- 执行DFS并跟踪所有节点:
O(nlog(n))
。步骤:-- 插入前执行DFS并跟踪所有节点:
O(n)
. - 计算插入节点到:
O(nlog(n))
的距离。 - 将现有的总和与步骤2中计算的总和相加
总体复杂度:
-O(nlog(n))
.- 插入前执行DFS并跟踪所有节点:
现在的问题是"是否存在O(n)
阶的解??"
我们可以通过遍历树两次来实现。
首先,我们需要三个数组 int []left
,其中存储了左子树的距离和。
int []right
,其中存储了右子树的距离和。
int []up
,存储父树(不含当前子树)的距离之和。
那么,第一次遍历,对于每个节点,我们计算左和右距离。如果节点是叶节点,则返回0,如果不是,则可以得到如下公式:
int cal(Node node){
int left = cal(node.left);
int right = cal(node.right);
left[node.index] = left;
right[node.index] = right;
//Depend on the current node have left or right node, we add 0,1 or 2 to the final result
int add = (node.left != null && node.right != null)? 2 : node.left != null ? 1 : node.right != null ? 1 : 0;
return left + right + add;
}
对于第二次遍历,我们需要在每个节点上加上到父节点的总距离。
1
/
2 3
/
4 5
例如节点1(根),总距离为left[1] + right[1] + 2
, up[1] = 0
;(我们加上2,因为根有左右子树,确切的公式是:
int add = 0;
if (node.left != null)
add++;
if(node.right != null)
add++;
对于节点2,总距离为left[2] + right[2] + add + up[1] + right[1] + 1 + addRight
, up[2] = up[1] + right[1] + addRight
。之所以在公式的末尾有一个1
是因为从当前节点到他的父节点有一条边,所以我们需要添加1
。现在,我表示当前节点的额外距离为add
,如果父节点中有左子树,则额外距离为addLeft
,类似地,右子树为addRight
。
节点3总距离为up[1] + left[1] + 1 + addLeft
, up[3] = up[1] + left[1] + addLeft
;
节点4总距离为up[2] + right[2] + 1 + addRight
, up[4] = up[2] + right[2] + addRight
;
因此,根据当前节点是左节点还是右节点,我们相应地更新up
。
时间复杂度为O(n)
是的,你可以用DP在O(n)中求出整棵树到每两个节点的距离和。简单地说,你应该知道三件事:
cnt[i] is the node count of the ith-node's sub-tree
dis[i] is the sum distance of every ith-node subtree's node to i-th node
ret[i] is the sum distance of the ith-node subtree between every two node
注意ret[root]
是问题的答案,所以只要正确计算ret[i]
,问题就解决了…如何计算ret[i]
?需要cnt[i]
和dis[i]
的帮助,递归求解。关键问题是:
给定ret[左]ret[右]dis[左]dis[右]cnt[左]cnt[右]调用ret[node] dis[node] cnt[node]
(node)
/
(left-subtree) (right subtree)
/
...(node x_i) ... ...(node y_i)...
important:x_i is the any node in left-subtree(not leaf!)
and y_i is the any node in right-subtree(not leaf either!).
cnt[node]
很简单,只等于cnt[left] + cnt[right] + 1
dis[node]
不是那么难,等于dis[left] + dis[right] + cnt[left] + cnt[right]
。原因:sigma(xi->left)是dis[left]
,所以sigma(xi->node)是dis[left] + cnt[left]
。
ret[node]
=三部分:
- x <子>子> j -> x <子>子>和y <子>子> -> y <子>子>,等于
ret[left] + ret[right]
。 - xi ->节点和yi ->节点,等于
dis[node]
。 - xi -> yj:
=σ(x <子>子>>节点-> y <子>子>),固定x <子> ,那么我们会问(左)*距离(x <子> ,节点)+σy(节点-> <子>子>),然后问(左)*距离(x <子> ,节点)+σ(节点->左-> y <子>子>),子>子>子>
是cnt[left]*distance(x_i,node) + cnt[left] + dis[left]
。
xi: cnt[left]*(cnt[right]+dis[right]) + cnt[right]*(cnt[left] + dis[left])
,则为2*cnt[left]*cnt[right] + dis[left]*cnt[right] + dis[right]*cnt[left]
。
将这三个部分相加,我们得到ret[i]
。递归地执行,我们将得到ret[root]
。
import java.util.Arrays;
public class BSTDistance {
int[] left;
int[] right;
int[] cnt;
int[] ret;
int[] dis;
int nNode;
public BSTDistance(int n) {// n is the number of node
left = new int[n];
right = new int[n];
cnt = new int[n];
ret = new int[n];
dis = new int[n];
Arrays.fill(left,-1);
Arrays.fill(right,-1);
nNode = n;
}
void add(int a, int b)
{
if (left[b] == -1)
{
left[b] = a;
}
else
{
right[b] = a;
}
}
int cal()
{
_cal(0);//assume root's idx is 0
return ret[0];
}
void _cal(int idx)
{
if (left[idx] == -1 && right[idx] == -1)
{
cnt[idx] = 1;
dis[idx] = 0;
ret[idx] = 0;
}
else if (left[idx] != -1 && right[idx] == -1)
{
_cal(left[idx]);
cnt[idx] = cnt[left[idx]] + 1;
dis[idx] = dis[left[idx]] + cnt[left[idx]];
ret[idx] = ret[left[idx]] + dis[idx];
}//left[idx] == -1 and right[idx] != -1 is impossible, guarranted by add(int,int)
else
{
_cal(left[idx]);
_cal(right[idx]);
cnt[idx] = cnt[left[idx]] + 1 + cnt[right[idx]];
dis[idx] = dis[left[idx]] + dis[right[idx]] + cnt[left[idx]] + cnt[right[idx]];
ret[idx] = dis[idx] + ret[left[idx]] + ret[right[idx]] + 2*cnt[left[idx]]*cnt[right[idx]] + dis[left[idx]]*cnt[right[idx]] + dis[right[idx]]*cnt[left[idx]];
}
}
public static void main(String[] args)
{
BSTDistance bst1 = new BSTDistance(3);
bst1.add(1, 0);
bst1.add(2, 0);
// (0)
// /
//(1) (2)
System.out.println(bst1.cal());
BSTDistance bst2 = new BSTDistance(5);
bst2.add(1, 0);
bst2.add(2, 0);
bst2.add(3, 1);
bst2.add(4, 1);
// (0)
// /
// (1) (2)
// /
// (3) (4)
//0 -> 1:1
//0 -> 2:1
//0 -> 3:2
//0 -> 4:2
//1 -> 2:2
//1 -> 3:1
//1 -> 4:1
//2 -> 3:3
//2 -> 4:3
//3 -> 4:2
//2*4+3*2+1*4=18
System.out.println(bst2.cal());
}
}
输出:4
18
为了方便(读者理解我的解决方案),我将cnt[],dis[] and ret[]
的值粘贴在bst2.cal()
之后:
cnt[] 5 3 1 1 1
dis[] 6 2 0 0 0
ret[] 18 4 0 0 0
PS:这是UESTC_elfness的解决方案,这对他来说是一个简单的问题,而我说,天哪,这个问题对我来说并不难。
所以你可以相信我们…
首先,为每个节点添加四个变量。四个变量分别是到左子代的距离和、到右子代的距离和、左子代的节点数和右子代的节点数。分别记为l、r、nl和nr。
其次,在根节点上添加一个total变量,记录每次插入后的总和。
这个想法是,如果你有当前树的总数,插入新节点后的新总数是(旧总数+新节点到所有其他节点的距离总和)。每次插入需要计算的是新节点到所有其他节点的距离之和。1- Insert the new node with four variable set to zero.
2- Create two temp counter "node travel" and "subtotal" with value zero.
3- Back trace the route from new node to root.
a- go up to parent node
b- add one to node travel
c- add node travel to subtotal
d- add (nr * node travel) + r to subtotal if the new node is on left offspring
e- add node travel to l
f- add one to nl
4- Add subtotal to total
1 - O(n)
2 - O(1)
3 - O(log n), a到f取O(1)
4 - 0 (1)
如果您的意思是每次插入O(n),那么这是可以做到的,假设您在每次插入之后都从根开始。
0- Record the current sum of the distances. Call it s1: O(1).
1- Insert the new node: O(n).
2- Perform a BFS, starting at this new node.
For each new node you discover, record its distance to the start (new) node, as BFS always does: O(n).
This gives you an array of the distances from the start node to all other nodes.
3- Sum these distances up. Call this s2: O(n).
4- New_sum = s1 + s2: O(1).
因此该算法为O(n)。