📌  相关文章
📜  查询二叉树的两个节点之间的距离 – O(logn) 方法

📅  最后修改于: 2021-10-27 09:04:32             🧑  作者: Mango

给定一棵二叉树(节点中没有父指针),任务是求出树中两个键(节点)之间的距离。两个节点之间的距离是指从其中一个节点到达另一个节点所需经过的最少边数。
这个问题在之前的文章中已经讨论过,但那种方法需要对二叉树进行三次遍历:一次用于查找两个节点(设为 A 和 B)的最低公共祖先(LCA),另外两次分别用于求 LCA 到 A 以及 LCA 到 B 的距离,每次查询的时间复杂度为 O(n)。本文将讨论一种只需 O(log(n)) 时间即可求出两个节点 LCA 的方法。

可以根据最低共同祖先获得两个节点之间的距离。以下是公式。

Dist(n1, n2) = Dist(root, n1) + Dist(root, n2) - 2*Dist(root, lca) 
'n1' and 'n2' are the two given keys
'root' is root of given Binary Tree.
'lca' is lowest common ancestor of n1 and n2
Dist(n1, n2) is the distance between n1 and n2.

上面的公式也可以写成:

Dist(n1, n2) = Level[n1] + Level[n2] - 2*Level[lca] 

这个问题可以分解为:

  1. 求出每个节点的层级(level)
  2. 求出二叉树的欧拉遍历序列(Euler tour)
  3. 在欧拉序列对应的层级数组上构建线段树,用于查询 LCA

这些步骤解释如下:

C++
// C++ program to find distance between
// two nodes for multiple queries
#include <climits>
#include <iostream>
#include <queue>
#include <utility>
#define MAX 100001
using namespace std;
 
/* Binary tree node: an integer key plus links to its two children. */
struct Node {
    int data;     // key stored in this node
    Node* left;   // left child (newNode() initialises it to NULL)
    Node* right;  // right child (newNode() initialises it to NULL)
};
 
/* Utility function to create a new Binary Tree node */
struct Node* newNode(int data)
{
    struct Node* temp = new struct Node;
    temp->data = data;
    temp->left = temp->right = NULL;
    return temp;
}
 
// Array to store level of each node
int level[MAX];
 
// Utility Function to store level of all nodes
void FindLevels(struct Node* root)
{
    if (!root)
        return;
 
    // queue to hold tree node with level
    queue > q;
 
    // let root node be at level 0
    q.push({ root, 0 });
 
    pair p;
 
    // Do level Order Traversal of tree
    while (!q.empty()) {
        p = q.front();
        q.pop();
 
        // Node p.first is on level p.second
        level[p.first->data] = p.second;
 
        // If left child exits, put it in queue
        // with current_level +1
        if (p.first->left)
            q.push({ p.first->left, p.second + 1 });
 
        // If right child exists, put it in queue
        // with current_level +1
        if (p.first->right)
            q.push({ p.first->right, p.second + 1 });
    }
}
 
// Stores Euler Tour
int Euler[MAX];
 
// index in Euler array
int idx = 0;
 
// Find Euler Tour
void eulerTree(struct Node* root)
{
 
    // store current node's data
    Euler[++idx] = root->data;
 
    // If left node exists
    if (root->left) {
 
        // traverse left subtree
        eulerTree(root->left);
 
        // store parent node's data
        Euler[++idx] = root->data;
    }
 
    // If right node exists
    if (root->right) {
        // traverse right subtree
        eulerTree(root->right);
 
        // store parent node's data
        Euler[++idx] = root->data;
    }
}
 
// checks for visited nodes
int vis[MAX];
 
// Stores level of Euler Tour
int L[MAX];
 
// Stores indices of first occurrence
// of nodes in Euler tour
int H[MAX];
 
// Preprocessing Euler Tour for finding LCA
void preprocessEuler(int size)
{
    for (int i = 1; i <= size; i++) {
        L[i] = level[Euler[i]];
 
        // If node is not visited before
        if (vis[Euler[i]] == 0) {
            // Add to first occurrence
            H[Euler[i]] = i;
 
            // Mark it visited
            vis[Euler[i]] = 1;
        }
    }
}
 
// Stores values and positions
pair seg[4 * MAX];
 
// Return whichever pair has the smaller `first` member.
// When both `first` values are equal, `a` is returned.
std::pair<int, int> min(std::pair<int, int> a,
                        std::pair<int, int> b)
{
    return (b.first < a.first) ? b : a;
}
 
// Utility function to build segment tree
pair buildSegTree(int low, int high, int pos)
{
    if (low == high) {
        seg[pos].first = L[low];
        seg[pos].second = low;
        return seg[pos];
    }
    int mid = low + (high - low) / 2;
    buildSegTree(low, mid, 2 * pos);
    buildSegTree(mid + 1, high, 2 * pos + 1);
 
    seg[pos] = min(seg[2 * pos], seg[2 * pos + 1]);
}
 
// Utility function to find LCA
pair LCA(int qlow, int qhigh, int low,
                   int high, int pos)
{
    if (qlow <= low && qhigh >= high)
        return seg[pos];
 
    if (qlow > high || qhigh < low)
        return { INT_MAX, 0 };
 
    int mid = low + (high - low) / 2;
 
    return min(LCA(qlow, qhigh, low, mid, 2 * pos),
               LCA(qlow, qhigh, mid + 1, high, 2 * pos + 1));
}
 
// Function to return distance between
// two nodes n1 and n2
int findDistance(int n1, int n2, int size)
{
    // Maintain original Values
    int prevn1 = n1, prevn2 = n2;
 
    // Get First Occurrence of n1
    n1 = H[n1];
 
    // Get First Occurrence of n2
    n2 = H[n2];
 
    // Swap if low > high
    if (n2 < n1)
        swap(n1, n2);
 
    // Get position of minimum value
    int lca = LCA(n1, n2, 1, size, 1).second;
 
    // Extract value out of Euler tour
    lca = Euler[lca];
 
    // return calculated distance
    return level[prevn1] + level[prevn2] - 2 * level[lca];
}
 
// Run every preprocessing step needed before findDistance() can be called.
// N is the number of nodes in the tree rooted at `root`.
void preProcessing(Node* root, int N)
{
    // The code treats the Euler tour of an N-node tree as 2N - 1 entries
    const int tourLen = 2 * N - 1;

    // Fill Euler[]
    eulerTree(root);

    // Fill level[]
    FindLevels(root);

    // Fill L[] and H[]
    preprocessEuler(tourLen);

    // Build the segment tree over L[1..tourLen]
    buildSegTree(1, tourLen, 1);
}
 
/* Driver function to test above functions */
int main()
{
    int N = 8; // Number of nodes
 
    /* Constructing tree given in the above figure */
    Node* root = newNode(1);
    root->left = newNode(2);
    root->right = newNode(3);
    root->left->left = newNode(4);
    root->left->right = newNode(5);
    root->right->left = newNode(6);
    root->right->right = newNode(7);
    root->right->left->right = newNode(8);
 
    // Function to do all preprocessing
    preProcessing(root, N);
 
    cout << "Dist(4, 5) = " <<
      findDistance(4, 5, 2 * N - 1) << "\n";
    cout << "Dist(4, 6) = " <<
      findDistance(4, 6, 2 * N - 1) << "\n";
    cout << "Dist(3, 4) = " <<
      findDistance(3, 4, 2 * N - 1) << "\n";
    cout << "Dist(2, 4) = " <<
      findDistance(2, 4, 2 * N - 1) << "\n";
    cout << "Dist(8, 5) = " <<
      findDistance(8, 5, 2 * N - 1) << "\n";
 
    return 0;
}


Java
// Java program to find distance between
// two nodes for multiple queries
import java.io.*;
import java.util.*;
 
/**
 * Answers distance queries between two keys of a binary tree.
 *
 * preProcessing() computes node levels, an Euler tour of the tree and a
 * range-minimum segment tree over the levels along that tour. Afterwards
 * findDistance() locates the lowest common ancestor with one segment-tree
 * query and applies Level[n1] + Level[n2] - 2 * Level[lca].
 */
class GFG
{
    // Size of every lookup array below. Node keys index level/vis/H and
    // tour positions index Euler/L, so both must stay below MAX.
    static int MAX = 100001;

    /* A tree node structure */
    static class Node
    {
        int data;
        Node left, right;

        Node(int data)
        {
            this.data = data;
            this.left = this.right = null;
        }
    }

    /* Minimal mutable two-field tuple */
    static class Pair<T, V>
    {
        T first;
        V second;

        Pair() {
        }

        Pair(T first, V second)
        {
            this.first = first;
            this.second = second;
        }
    }

    // level[k] is the depth (root = 0) of the node whose key is k
    static int[] level = new int[MAX];

    /**
     * Stores the depth of every node in level[] using a breadth-first
     * traversal. A null root is a no-op.
     */
    static void findLevels(Node root)
    {
        if (root == null)
            return;

        // Each queue entry pairs a node with its depth
        Queue<Pair<Node, Integer>> q = new LinkedList<>();

        // The root sits at depth 0
        q.add(new Pair<>(root, 0));

        while (!q.isEmpty())
        {
            Pair<Node, Integer> p = q.poll();

            // Node p.first is on level p.second
            level[p.first.data] = p.second;

            // Children are one level deeper than their parent
            if (p.first.left != null)
                q.add(new Pair<>(p.first.left, p.second + 1));

            if (p.first.right != null)
                q.add(new Pair<>(p.first.right, p.second + 1));
        }
    }

    // Euler[1..idx] holds the keys in Euler-tour order (1-based)
    static int[] Euler = new int[MAX];

    // Position of the last entry written to Euler
    static int idx = 0;

    /**
     * Appends the Euler tour of the subtree rooted at root to Euler[].
     * A node's key is written on entry and again after returning from each
     * existing child. A null root contributes nothing.
     */
    static void eulerTree(Node root)
    {
        // Guard added so that an empty tree does not throw
        if (root == null)
            return;

        // Entering this node
        Euler[++idx] = root.data;

        if (root.left != null)
        {
            eulerTree(root.left);

            // Back at this node after the left subtree
            Euler[++idx] = root.data;
        }

        if (root.right != null)
        {
            eulerTree(root.right);

            // Back at this node after the right subtree
            Euler[++idx] = root.data;
        }
    }

    // vis[k] becomes 1 once key k has been met in the tour
    static int[] vis = new int[MAX];

    // L[i] is the depth of the node found at tour position i
    static int[] L = new int[MAX];

    // H[k] is the first tour position at which key k appears
    static int[] H = new int[MAX];

    /** Fills L[] and H[] from tour positions 1..size. */
    static void preprocessEuler(int size)
    {
        for (int i = 1; i <= size; i++)
        {
            L[i] = level[Euler[i]];

            // Only the first sighting of a key is recorded in H[]
            if (vis[Euler[i]] == 0)
            {
                H[Euler[i]] = i;
                vis[Euler[i]] = 1;
            }
        }
    }

    // Segment tree storage (root at index 1): each entry is
    // (minimum depth in the node's range, tour position of that minimum).
    // Entries are allocated in preProcessing().
    @SuppressWarnings("unchecked")
    static Pair<Integer, Integer>[] seg =
        (Pair<Integer, Integer>[]) new Pair[4 * MAX];

    /** Returns the pair with the smaller first; a wins a tie. */
    static Pair<Integer, Integer> min(Pair<Integer, Integer> a,
                                      Pair<Integer, Integer> b)
    {
        if (a.first <= b.first)
            return a;
        return b;
    }

    /**
     * Builds the segment tree over L[low..high] with this range stored at
     * index pos, and returns the entry stored for the range.
     */
    static Pair<Integer, Integer> buildSegTree(int low,
                                               int high, int pos)
    {
        // Leaf: a single tour position
        if (low == high)
        {
            seg[pos].first = L[low];
            seg[pos].second = low;
            return seg[pos];
        }
        int mid = low + (high - low) / 2;
        buildSegTree(low, mid, 2 * pos);
        buildSegTree(mid + 1, high, 2 * pos + 1);

        seg[pos] = min(seg[2 * pos], seg[2 * pos + 1]);

        return seg[pos];
    }

    /**
     * Range-minimum query: [qlow, qhigh] is the queried range, [low, high]
     * the range covered by tree node pos. Returns (Integer.MAX_VALUE, 0)
     * when the ranges do not overlap.
     */
    static Pair<Integer, Integer> LCA(int qlow, int qhigh,
                                      int low, int high, int pos)
    {
        // Node range lies entirely inside the query range
        if (qlow <= low && qhigh >= high)
            return seg[pos];

        // No overlap: sentinel that never wins against a real depth
        if (qlow > high || qhigh < low)
            return new Pair<Integer, Integer>(Integer.MAX_VALUE, 0);

        int mid = low + (high - low) / 2;

        return min(LCA(qlow, qhigh, low, mid, 2 * pos),
                   LCA(qlow, qhigh, mid + 1, high, 2 * pos + 1));
    }

    /**
     * Returns the number of edges between the nodes with keys n1 and n2.
     * size is the length of the Euler tour the segment tree was built over.
     */
    static int findDistance(int n1, int n2, int size)
    {
        // Keep the original keys for the level lookups
        int prevn1 = n1, prevn2 = n2;

        // Replace the keys by their first tour positions
        n1 = H[n1];
        n2 = H[n2];

        // Query range must be ordered
        if (n2 < n1)
        {
            int temp = n1;
            n1 = n2;
            n2 = temp;
        }

        // Tour position with the minimum depth in the range
        int lca = LCA(n1, n2, 1, size, 1).second;

        // Key of the lowest common ancestor
        lca = Euler[lca];

        return level[prevn1] + level[prevn2] -
                                  2 * level[lca];
    }

    /**
     * Runs every preprocessing step; N is the number of nodes in the tree.
     */
    static void preProcessing(Node root, int N)
    {
        // Segment-tree entries are mutated in place, so allocate them first
        for (int i = 0; i < 4 * MAX; i++)
        {
            seg[i] = new Pair<>();
        }

        // Fill Euler[]
        eulerTree(root);

        // Fill level[]
        findLevels(root);

        // Fill L[] and H[]
        preprocessEuler(2 * N - 1);

        // Build the segment tree over L[1..2N-1]
        buildSegTree(1, 2 * N - 1, 1);
    }

    // Driver Code
    public static void main(String[] args)
    {

        // Number of nodes
        int N = 8;

        /* Sample tree: 1 -> (2, 3), 2 -> (4, 5), 3 -> (6, 7), 6 -> (-, 8) */
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        // Function to do all preprocessing
        preProcessing(root, N);

        System.out.println("Dist(4, 5) = " +
                        findDistance(4, 5, 2 * N - 1));
        System.out.println("Dist(4, 6) = " +
                        findDistance(4, 6, 2 * N - 1));
        System.out.println("Dist(3, 4) = " +
                        findDistance(3, 4, 2 * N - 1));
        System.out.println("Dist(2, 4) = " +
                        findDistance(2, 4, 2 * N - 1));
        System.out.println("Dist(8, 5) = " +
                        findDistance(8, 5, 2 * N - 1));
    }
}
 
// This code is contributed by
// sanjeev2552


Python3
# Python3 program to find distance between
# two nodes for multiple queries
 
from collections import deque
from sys import maxsize as INT_MAX
 
MAX = 100001
 
# Binary tree node
class Node:
    """A key plus references to the left and right children."""

    def __init__(self, data):
        # Both children start out absent
        self.left = self.right = None
        self.data = data
 
# level[k] is the depth (root = 0) of the node whose key is k
level = [0] * MAX

# Breadth-first pass that records every node's depth
def findLevels(root: Node):
    """Store the depth of each node in level[], indexed by node key.

    A root of None leaves level[] untouched.
    """
    global level

    if root is None:
        return

    # Each queue entry is a (node, depth) tuple; the root is at depth 0
    pending = deque()
    pending.append((root, 0))

    while pending:
        node, depth = pending.popleft()

        level[node.data] = depth

        # Existing children are one level deeper, left before right
        for child in (node.left, node.right):
            if child:
                pending.append((child, depth + 1))
 
# Euler[1..idx] holds the keys in Euler-tour order (1-based)
Euler = [0] * MAX

# position of the last entry written to Euler
idx = 0

# Find Euler Tour
def eulerTree(root: Node):
    """Append the Euler tour of the subtree rooted at `root` to Euler.

    A node's key is written on entry and again after returning from each
    existing child. A root of None contributes nothing.
    """
    global Euler, idx

    # Guard added so that an empty tree does not raise AttributeError
    if root is None:
        return

    # entering this node
    idx += 1
    Euler[idx] = root.data

    for child in (root.left, root.right):
        if child:
            eulerTree(child)

            # back at this node after finishing the child's subtree
            idx += 1
            Euler[idx] = root.data
 
# vis[k] becomes 1 once key k has been met in the tour
vis = [0] * MAX

# L[i] is the depth of the node found at tour position i
L = [0] * MAX

# H[k] is the first tour position at which key k appears
H = [0] * MAX

# Derive L and H from the Euler tour
def preprocessEuler(size: int):
    """Fill L and H from tour positions 1..size."""
    global L, H, vis
    for pos in range(1, size + 1):
        key = Euler[pos]
        L[pos] = level[key]

        # Only the first sighting of a key is recorded in H
        if vis[key] != 0:
            continue
        H[key] = pos
        vis[key] = 1
 
# Segment tree storage: one independent [value, position] list per slot
seg = []
for i in range(4 * MAX):
    seg.append([0, 0])
 
# Pick the pair with the smaller first element
def minPair(a: list, b: list) -> list:
    """Return `a` when a[0] <= b[0], otherwise `b` (so `a` wins a tie)."""
    return a if a[0] <= b[0] else b
 
# Utility function to build segment tree
def buildSegTree(low: int, high: int,
                       pos: int) -> list:
    """Build the range-minimum segment tree over L[low..high].

    The entry for this range is stored at seg[pos] as
    [minimum depth, tour position of that minimum] and is also returned.
    """
    # leaf: a single tour position
    if low == high:
        seg[pos][0] = L[low]
        seg[pos][1] = low
        return seg[pos]

    mid = low + (high - low) // 2
    buildSegTree(low, mid, 2 * pos)
    buildSegTree(mid + 1, high, 2 * pos + 1)

    # Use minPair (not the builtin min) so ties are resolved exactly like
    # the queries do, and store a copy so that a parent never shares one
    # list object with its child.
    seg[pos] = list(minPair(seg[2 * pos], seg[2 * pos + 1]))

    # The original returned None here despite the `-> list` annotation
    return seg[pos]
 
# Range-minimum query on the segment tree
def LCA(qlow: int, qhigh: int, low: int,
                     high: int, pos: int) -> list:
    """Return [min depth, tour position] for the queried range.

    [qlow, qhigh] is the queried range and [low, high] the range covered
    by tree node `pos`. Non-overlapping ranges yield [INT_MAX, 0].
    """
    # node range lies entirely inside the query range
    fully_covered = qlow <= low and qhigh >= high
    if fully_covered:
        return seg[pos]

    # no overlap: sentinel value
    disjoint = qlow > high or qhigh < low
    if disjoint:
        return [INT_MAX, 0]

    mid = low + (high - low) // 2

    left_best = LCA(qlow, qhigh, low, mid, 2 * pos)
    right_best = LCA(qlow, qhigh, mid + 1, high, 2 * pos + 1)
    return minPair(left_best, right_best)
 
# Function to return distance between
# two nodes n1 and n2
def findDistance(n1: int, n2: int, size: int) -> int:
    """Return the number of edges between the nodes with keys n1 and n2.

    `size` is the length of the Euler tour the segment tree was built over.
    """
    # First tour positions of both keys, ordered so that first <= last
    first = H[n1]
    last = H[n2]
    if last < first:
        first, last = last, first

    # Tour position with the minimum depth in the range
    lca_pos = LCA(first, last, 1, size, 1)[1]

    # Key of the lowest common ancestor
    lca_key = Euler[lca_pos]

    # Parenthesised: the original broke the expression across two lines
    # without a continuation, which is a SyntaxError.
    return (level[n1] + level[n2]
            - 2 * level[lca_key])
 
def preProcessing(root: Node, N: int):
    """Run every preprocessing step; N is the number of nodes in the tree."""

    # The code treats the Euler tour of an N-node tree as 2N - 1 entries
    tour_len = 2 * N - 1

    # Fill Euler
    eulerTree(root)

    # Fill level
    findLevels(root)

    # Fill L and H
    preprocessEuler(tour_len)

    # Build the segment tree over L[1..tour_len]
    buildSegTree(1, tour_len, 1)
 
# Driver: build the sample tree, preprocess it and print five distances
if __name__ == "__main__":

    # Number of nodes
    N = 8

    # Sample tree: 1 -> (2, 3), 2 -> (4, 5), 3 -> (6, 7), 6 -> (-, 8)
    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.right.left = Node(6)
    root.right.right = Node(7)
    root.right.left.right = Node(8)

    # Function to do all preprocessing
    preProcessing(root, N)

    # Length of the Euler tour handed to every query
    tour_len = 2 * N - 1

    for label, a, b in (("Dist(4, 5) =", 4, 5),
                        ("Dist(4, 6) =", 4, 6),
                        ("Dist(3, 4) =", 3, 4),
                        ("Dist(2, 4) =", 2, 4),
                        ("Dist(8, 5) =", 8, 5)):
        print(label, findDistance(a, b, tour_len))
 
# This code is contributed by
# sanjeev2552


C#
// C# program to find distance between
// two nodes for multiple queries
using System;
using System.Collections.Generic;
 
/// <summary>
/// Answers distance queries between two keys of a binary tree.
/// preProcessing() computes node levels, an Euler tour of the tree and a
/// range-minimum segment tree over the levels along that tour; afterwards
/// findDistance() locates the lowest common ancestor with one segment-tree
/// query and applies Level[n1] + Level[n2] - 2 * Level[lca].
/// </summary>
class GFG
{
    // Size of every lookup array below. Node keys index level/vis/H and
    // tour positions index Euler/L, so both must stay below MAX.
    static int MAX = 100001;

    /* A tree node structure */
    public class Node
    {
        public int data;
        public Node left, right;

        public Node(int data)
        {
            this.data = data;
            this.left = this.right = null;
        }
    }

    /* Minimal mutable two-field tuple */
    class Pair<T, V>
    {
        public T first;
        public V second;

        public Pair() {
        }

        public Pair(T first, V second)
        {
            this.first = first;
            this.second = second;
        }
    }

    // level[k] is the depth (root = 0) of the node whose key is k
    static int[] level = new int[MAX];

    // Stores the depth of every node in level[] using a breadth-first
    // traversal. A null root is a no-op.
    static void findLevels(Node root)
    {
        if (root == null)
            return;

        // Queue<T> gives O(1) dequeue; the original removed the head of a
        // List with RemoveAt(0), which shifts every remaining element.
        Queue<Pair<Node, int>> q = new Queue<Pair<Node, int>>();

        // The root sits at depth 0
        q.Enqueue(new Pair<Node, int>(root, 0));

        while (q.Count != 0)
        {
            Pair<Node, int> p = q.Dequeue();

            // Node p.first is on level p.second
            level[p.first.data] = p.second;

            // Children are one level deeper than their parent
            if (p.first.left != null)
                q.Enqueue(new Pair<Node, int>
                          (p.first.left, p.second + 1));

            if (p.first.right != null)
                q.Enqueue(new Pair<Node, int>
                          (p.first.right, p.second + 1));
        }
    }

    // Euler[1..idx] holds the keys in Euler-tour order (1-based)
    static int[] Euler = new int[MAX];

    // Position of the last entry written to Euler
    static int idx = 0;

    // Appends the Euler tour of the subtree rooted at root to Euler[].
    // A node's key is written on entry and again after returning from each
    // existing child. A null root contributes nothing.
    static void eulerTree(Node root)
    {
        // Guard added so that an empty tree does not throw
        if (root == null)
            return;

        // Entering this node
        Euler[++idx] = root.data;

        if (root.left != null)
        {
            eulerTree(root.left);

            // Back at this node after the left subtree
            Euler[++idx] = root.data;
        }

        if (root.right != null)
        {
            eulerTree(root.right);

            // Back at this node after the right subtree
            Euler[++idx] = root.data;
        }
    }

    // vis[k] becomes 1 once key k has been met in the tour
    static int[] vis = new int[MAX];

    // L[i] is the depth of the node found at tour position i
    static int[] L = new int[MAX];

    // H[k] is the first tour position at which key k appears
    static int[] H = new int[MAX];

    // Fills L[] and H[] from tour positions 1..size
    static void preprocessEuler(int size)
    {
        for (int i = 1; i <= size; i++)
        {
            L[i] = level[Euler[i]];

            // Only the first sighting of a key is recorded in H[]
            if (vis[Euler[i]] == 0)
            {
                H[Euler[i]] = i;
                vis[Euler[i]] = 1;
            }
        }
    }

    // Segment tree storage (root at index 1): each entry is
    // (minimum depth in the node's range, tour position of that minimum).
    // Entries are allocated in preProcessing().
    static Pair<int, int>[] seg = new
                          Pair<int, int>[4 * MAX];

    // Returns the pair with the smaller first; a wins a tie
    static Pair<int, int> min(Pair<int, int> a,
                              Pair<int, int> b)
    {
        if (a.first <= b.first)
            return a;
        return b;
    }

    // Builds the segment tree over L[low..high] with this range stored at
    // index pos, and returns the entry stored for the range.
    static Pair<int, int> buildSegTree(int low,
                                       int high, int pos)
    {
        // Leaf: a single tour position
        if (low == high)
        {
            seg[pos].first = L[low];
            seg[pos].second = low;
            return seg[pos];
        }
        int mid = low + (high - low) / 2;
        buildSegTree(low, mid, 2 * pos);
        buildSegTree(mid + 1, high, 2 * pos + 1);

        seg[pos] = min(seg[2 * pos], seg[2 * pos + 1]);

        return seg[pos];
    }

    // Range-minimum query: [qlow, qhigh] is the queried range, [low, high]
    // the range covered by tree node pos. Returns (int.MaxValue, 0) when
    // the ranges do not overlap.
    static Pair<int, int> LCA(int qlow, int qhigh,
                              int low, int high, int pos)
    {
        // Node range lies entirely inside the query range
        if (qlow <= low && qhigh >= high)
            return seg[pos];

        // No overlap: sentinel that never wins against a real depth
        if (qlow > high || qhigh < low)
            return new Pair<int, int>(int.MaxValue, 0);

        int mid = low + (high - low) / 2;

        return min(LCA(qlow, qhigh, low, mid, 2 * pos),
                   LCA(qlow, qhigh, mid + 1,
                                    high, 2 * pos + 1));
    }

    // Returns the number of edges between the nodes with keys n1 and n2.
    // size is the length of the Euler tour the segment tree was built over.
    static int findDistance(int n1, int n2, int size)
    {
        // Keep the original keys for the level lookups
        int prevn1 = n1, prevn2 = n2;

        // Replace the keys by their first tour positions
        n1 = H[n1];
        n2 = H[n2];

        // Query range must be ordered
        if (n2 < n1)
        {
            int temp = n1;
            n1 = n2;
            n2 = temp;
        }

        // Tour position with the minimum depth in the range
        int lca = LCA(n1, n2, 1, size, 1).second;

        // Key of the lowest common ancestor
        lca = Euler[lca];

        return level[prevn1] + level[prevn2] -
                                2 * level[lca];
    }

    // Runs every preprocessing step; N is the number of nodes in the tree
    static void preProcessing(Node root, int N)
    {
        // Segment-tree entries are mutated in place, so allocate them first
        for (int i = 0; i < 4 * MAX; i++)
        {
            seg[i] = new Pair<int, int>();
        }

        // Fill Euler[]
        eulerTree(root);

        // Fill level[]
        findLevels(root);

        // Fill L[] and H[]
        preprocessEuler(2 * N - 1);

        // Build the segment tree over L[1..2N-1]
        buildSegTree(1, 2 * N - 1, 1);
    }

    // Driver Code
    public static void Main(String[] args)
    {

        // Number of nodes
        int N = 8;

        /* Sample tree: 1 -> (2, 3), 2 -> (4, 5), 3 -> (6, 7), 6 -> (-, 8) */
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        // Function to do all preprocessing
        preProcessing(root, N);

        Console.WriteLine("Dist(4, 5) = " +
                       findDistance(4, 5, 2 * N - 1));
        Console.WriteLine("Dist(4, 6) = " +
                       findDistance(4, 6, 2 * N - 1));
        Console.WriteLine("Dist(3, 4) = " +
                       findDistance(3, 4, 2 * N - 1));
        Console.WriteLine("Dist(2, 4) = " +
                       findDistance(2, 4, 2 * N - 1));
        Console.WriteLine("Dist(8, 5) = " +
                       findDistance(8, 5, 2 * N - 1));
    }
}
 
// This code is contributed by Rajput-Ji


输出

Dist(4, 5) = 2
Dist(4, 6) = 4
Dist(3, 4) = 3
Dist(2, 4) = 1
Dist(8, 5) = 5

时间复杂度: 预处理 O(N),每次查询 O(Log N)
空间复杂度: O(N)
查询二叉树的两个节点之间的距离 – O(1) 方法