算法-二叉搜索树的每两个节点在O(n)内的距离之和



问题是在每个父子对被单位距离分隔的情况下,找出BinarySearchTree中每两个节点之间的距离之和。每次插入后都要计算。

,

 ->first node is inserted..
      (root)
   total sum=0;
->left and right node are inserted
      (root)
      /    
  (left)   (right)
   total sum = distance(root,left)+distance(root,right)+distance(left,right);
             =        1           +         1          +         2
             =     4
and so on.....

我想到的解决方案:

  1. 蛮力。步骤:

    1. 执行DFS并跟踪所有节点:O(n) .
    2. 选取每两个节点,用最低共同祖先法计算它们之间的O(nC2)_times_O(log(n))=O(n2log(n))距离并相加。

    总体复杂度:-O(n2log(n))

  2. O(nlog(n))。步骤:-

    1. 插入前执行DFS并跟踪所有节点:O(n) .
    2. 计算插入节点到:O(nlog(n))的距离。
    3. 将现有的总和与步骤2中计算的总和相加

    总体复杂度:-O(nlog(n)) .

现在的问题是"是否存在O(n)阶的解??"

我们可以通过遍历树两次来实现。

首先,我们需要三个数组

int []left,其中存储了左子树的距离和。

int []right,其中存储了右子树的距离和。

int []up,存储父树(不含当前子树)的距离之和。

那么,第一次遍历,对于每个节点,我们计算左和右距离。如果节点是叶节点,则返回0,如果不是,则可以得到如下公式:

int cal(Node node){
    int left = cal(node.left);
    int right = cal(node.right);
    left[node.index] = left;
    right[node.index] = right;
    //Depend on the current node have left or right node, we add 0,1 or 2 to the final result
    int add = (node.left != null && node.right != null)? 2 : node.left != null ? 1 : node.right != null ? 1 : 0;
    return left + right + add;
}

对于第二次遍历,我们需要在每个节点上加上到父节点的总距离。

             1
            / 
           2   3
          / 
         4   5

例如节点1(根),总距离为left[1] + right[1] + 2, up[1] = 0;(我们加上2,因为根有左右子树,确切的公式是:

int add = 0; 
if (node.left != null) 
    add++;
if(node.right != null)
    add++;

对于节点2,总距离为left[2] + right[2] + add + up[1] + right[1] + 1 + addRight, up[2] = up[1] + right[1] + addRight。之所以在公式的末尾有一个1是因为从当前节点到他的父节点有一条边,所以我们需要添加1。现在,我表示当前节点的额外距离为add,如果父节点中有左子树,则额外距离为addLeft,类似地,右子树为addRight

节点3总距离为up[1] + left[1] + 1 + addLeft, up[3] = up[1] + left[1] + addLeft;

节点4总距离为up[2] + right[2] + 1 + addRight, up[4] = up[2] + right[2] + addRight;

因此,根据当前节点是左节点还是右节点,我们相应地更新up

时间复杂度为O(n)

是的,你可以用DP在O(n)中求出整棵树到每两个节点的距离和。简单地说,你应该知道三件事:

cnt[i] is the node count of the ith-node's sub-tree
dis[i] is the sum distance of every ith-node subtree's node to i-th node
ret[i] is the sum distance of the ith-node subtree between every two node

注意ret[root]是问题的答案,所以只要正确计算ret[i],问题就解决了…如何计算ret[i] ?需要cnt[i]dis[i]的帮助,递归求解。关键问题是:

给定ret[左]ret[右]dis[左]dis[右]cnt[左]cnt[右]调用ret[node] dis[node] cnt[node]

              (node)
          /             
    (left-subtree) (right subtree)
      /                   
...(node x_i) ...   ...(node y_i)...
important:x_i is the any node in left-subtree(not leaf!) 
and y_i is the any node in right-subtree(not leaf either!).

cnt[node]很简单,只等于cnt[left] + cnt[right] + 1

dis[node]不是那么难,等于dis[left] + dis[right] + cnt[left] + cnt[right]。原因:sigma(xi->left)是dis[left],所以sigma(xi->node)是dis[left] + cnt[left]

ret[node] =三部分:

  1. x <子> j -> x <子>和y <子> -> y <子>,等于ret[left] + ret[right]
  2. xi ->节点和yi ->节点,等于dis[node]
  3. xi -> yj:

=σ(x <子>>节点-> y <子>),固定x <子> ,那么我们会问(左)*距离(x <子> ,节点)+σy(节点-> <子>),然后问(左)*距离(x <子> ,节点)+σ(节点->左-> y <子>),

cnt[left]*distance(x_i,node) + cnt[left] + dis[left]

xi: cnt[left]*(cnt[right]+dis[right]) + cnt[right]*(cnt[left] + dis[left]),则为2*cnt[left]*cnt[right] + dis[left]*cnt[right] + dis[right]*cnt[left]

将这三个部分相加,我们得到ret[i]。递归地执行,我们将得到ret[root]

我代码:

import java.util.Arrays;
public class BSTDistance {
    int[] left;
    int[] right;
    int[] cnt;
    int[] ret;
    int[] dis;
    int nNode;
    public BSTDistance(int n) {// n is the number of node
        left = new int[n];
        right = new int[n];
        cnt = new int[n];
        ret = new int[n];
        dis = new int[n];
        Arrays.fill(left,-1);
        Arrays.fill(right,-1);
        nNode = n;
    }
    void add(int a, int b)
    {
        if (left[b] == -1)
        {
            left[b] = a;
        }
        else
        {
            right[b] = a;
        }
    }
    int cal()
    {
        _cal(0);//assume root's idx is 0
        return ret[0];
    }
    void _cal(int idx)
    {
        if (left[idx] == -1 && right[idx] == -1)
        {
            cnt[idx] = 1;
            dis[idx] = 0;
            ret[idx] = 0;
        }
        else if (left[idx] != -1  && right[idx] == -1)
        {
            _cal(left[idx]);
            cnt[idx] = cnt[left[idx]] + 1;
            dis[idx] = dis[left[idx]] + cnt[left[idx]];
            ret[idx] = ret[left[idx]] + dis[idx];
        }//left[idx] == -1 and right[idx] != -1 is impossible, guarranted by add(int,int)  
        else 
        {
            _cal(left[idx]);
            _cal(right[idx]);
            cnt[idx] = cnt[left[idx]] + 1 + cnt[right[idx]];
            dis[idx] = dis[left[idx]] + dis[right[idx]] + cnt[left[idx]] + cnt[right[idx]];
            ret[idx] = dis[idx] + ret[left[idx]] + ret[right[idx]] + 2*cnt[left[idx]]*cnt[right[idx]] + dis[left[idx]]*cnt[right[idx]] + dis[right[idx]]*cnt[left[idx]];
        }
    }
    public static void main(String[] args)
    {
        BSTDistance bst1 = new BSTDistance(3);
        bst1.add(1, 0);
        bst1.add(2, 0);
        //   (0)
        //  /   
        //(1)   (2)
        System.out.println(bst1.cal());
        BSTDistance bst2 = new BSTDistance(5);
        bst2.add(1, 0);
        bst2.add(2, 0);
        bst2.add(3, 1);
        bst2.add(4, 1);
        //       (0)
        //      /   
        //    (1)   (2)
        //   /   
        // (3)   (4)
        //0 -> 1:1
        //0 -> 2:1
        //0 -> 3:2
        //0 -> 4:2
        //1 -> 2:2
        //1 -> 3:1
        //1 -> 4:1
        //2 -> 3:3
        //2 -> 4:3
        //3 -> 4:2
        //2*4+3*2+1*4=18
        System.out.println(bst2.cal());
    }
}
输出:

4
18

为了方便(读者理解我的解决方案),我将cnt[],dis[] and ret[]的值粘贴在bst2.cal()之后:

cnt[] 5 3 1 1 1 
dis[] 6 2 0 0 0
ret[] 18 4 0 0 0 

PS:这是UESTC_elfness的解决方案,这对他来说是一个简单的问题,而我说,天哪,这个问题对我来说并不难。

所以你可以相信我们…

首先,为每个节点添加四个变量。四个变量分别是到左子代的距离和、到右子代的距离和、左子代的节点数和右子代的节点数。分别记为l、r、nl和nr。

其次,在根节点上添加一个total变量,记录每次插入后的总和。

这个想法是,如果你有当前树的总数,插入新节点后的新总数是(旧总数+新节点到所有其他节点的距离总和)。每次插入需要计算的是新节点到所有其他节点的距离之和。
1- Insert the new node with four variable set to zero.
2- Create two temp counter "node travel" and "subtotal" with value zero.
3- Back trace the route from new node to root. 
   a- go up to parent node
   b- add one to node travel 
   c- add node travel to subtotal
   d- add (nr * node travel) + r to subtotal if the new node is on left offspring
   e- add node travel to l
   f- add one to nl
4- Add subtotal to total

1 - O(n)

2 - O(1)

3 - O(log n), a到f取O(1)

4 - 0 (1)

如果您的意思是每次插入O(n),那么这是可以做到的,假设您在每次插入之后都从根开始。

0- Record the current sum of the distances. Call it s1: O(1).
1- Insert the new node: O(n).
2- Perform a BFS, starting at this new node.
   For each new node you discover, record its distance to the start (new) node, as BFS always does: O(n).
   This gives you an array of the distances from the start node to all other nodes.
3- Sum these distances up. Call this s2: O(n).
4- New_sum = s1 + s2: O(1).

因此该算法为O(n)。

最新更新