📜  欧拉序(Euler Tour)| 使用线段树求子树和

📅  最后修改于: 2021-04-17 08:53:09             🧑  作者: Mango

欧拉序树(Euler Tour Tree,ETT)是一种把有根树表示为数字序列的方法。用深度优先搜索(DFS)遍历树时,每个节点会被插入向量两次:进入该节点时插入一次,访问完它的所有子节点之后再插入一次。这种方法非常适合解决子树类问题,子树求和就是其中之一。

先决条件:线段树(区间求和)

朴素方法:
考虑下图所示的有根树,它有6个顶点,以1为根,边为 1–2、1–3、2–6、2–5、3–4。朴素做法是对每个查询都执行一次DFS。
每个节点的权重写在括号内;在本例中,节点i的权重就是i。

查询:
1. 求以节点1为根的子树中所有节点的权重之和。
2. 将节点6的权重更新为10。
3. 求以节点2为根的子树中所有节点的权重之和。

答案:
1. 6 + 5 + 4 + 3 + 2 + 1 = 21
2. (更新操作,没有输出)

3. 10 + 5 + 2 = 17
时间复杂度分析:
每个此类查询都可以通过一次深度优先搜索(DFS)在O(n)时间内完成。

高效方法:
借助欧拉序技术把有根树转换为一棵线段树,可以把每个查询的时间复杂度降低到O(log(n))。因此,当查询数为q时,总时间复杂度为O(q * log(n))(另加建树所需的O(n)预处理时间)。

欧拉序:
在欧拉序技术中,每个顶点会被加入向量两次:DFS向下进入该顶点时加入一次,离开该顶点时再加入一次。
下面借助前面的示例来理解:

在给定的有根树上使用欧拉序技术执行深度优先搜索(DFS)时,得到的向量为:

s[]={1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1}
C++
// DFS function to traverse the tree
int dfs(int root)
{
    s.push_back(root);
    if (v[root].size() == 0)
        return root;
 
    for (int i = 0; i & lt; v[root].size(); i++) {
        int temp = dfs(v[root][i]);
        s.push_back(temp);
    }
    return root;
}


C++
// C++ program for implementation of
// Euler Tour | Subtree Sum.
#include <iostream>
#include <utility>
#include <vector>
using namespace std;

// v[x] lists the children of node x (adjacency list of the rooted tree).
vector<int> v[1001];
// Euler-tour sequence: every node appears exactly twice (entry and exit).
vector<int> s;
// Segment tree over s[]. A tree with up to 1001 nodes yields a tour of up
// to 2 * 1001 entries, and a recursive segment tree needs up to 4x as many
// slots as leaves, hence 8 * 1001 (1001 slots only cover ~250 entries).
int seg[8 * 1001] = { 0 };

// Value/Weight of each node of tree,
// value of 0th(no such node) node is 0.
int ar[] = { 0, 1, 2, 3, 4, 5, 6 };

int vertices = 6;
int edges = 5;
 
// A recursive function that constructs
// Segment Tree for array ar[] = { }.
// 'pos' is index of current node
// in segment tree seg[].
int segment(int low, int high, int pos)
{
    if (high == low) {
        seg[pos] = ar[s[low]];
    }
    else {
        int mid = (low + high) / 2;
        segment(low, mid, 2 * pos);
        segment(mid + 1, high, 2 * pos + 1);
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}
 
/* Return sum of elements in range
   from index l to r . It uses the
   seg[] array created using segment()
   function. 'pos' is index of current
   node in segment tree seg[].
*/
int query(int node, int start,
          int end, int l, int r)
{
    if (r < start || end < l) {
        return 0;
    }
 
    if (l <= start && end <= r) {
        return seg[node];
    }
 
    int mid = (start + end) / 2;
    int p1 = query(2 * node, start,
                   mid, l, r);
    int p2 = query(2 * node + 1, mid + 1,
                   end, l, r);
 
    return (p1 + p2);
}
 
/* A recursive function to update the
   nodes which have the given index in
   their range. The following are
   parameters pos --> index of current
   node in segment tree seg[]. idx -->
   index of the element to be updated.
   This index is in input array.
   val --> Value to be change at node idx
*/
int update(int pos, int low, int high,
           int idx, int val)
{
    if (low == high) {
        seg[pos] = val;
    }
    else {
        int mid = (low + high) / 2;
 
        if (low <= idx && idx <= mid) {
            update(2 * pos, low, mid,
                   idx, val);
        }
        else {
            update(2 * pos + 1, mid + 1,
                   high, idx, val);
        }
 
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}
 
/* A recursive function to form array
    ar[] from a directed tree .
*/
int dfs(int root)
{
    // pushing each node in vector s
    s.push_back(root);
    if (v[root].size() == 0)
        return root;
 
    for (int i = 0; i < v[root].size(); i++) {
        int temp = dfs(v[root][i]);
        s.push_back(temp);
    }
    return root;
}
 
// Driver program to test above functions
int main()
{
    // Edges between the nodes
    v[1].push_back(2);
    v[1].push_back(3);
    v[2].push_back(6);
    v[2].push_back(5);
    v[3].push_back(4);
 
    // Calling dfs function.
    int temp = dfs(1);
    s.push_back(temp);
 
    // Storing entry time and exit
    // time of each node
    vector > p;
 
    for (int i = 0; i <= vertices; i++)
        p.push_back(make_pair(0, 0));
 
    for (int i = 0; i < s.size(); i++) {
        if (p[s[i]].first == 0)
            p[s[i]].first = i + 1;
        else
            p[s[i]].second = i + 1;
    }
 
    // Build segment tree from array ar[].
    segment(0, s.size() - 1, 1);
 
    // query of type 1 return the
    // sum of subtree at node 1.
    int node = 1;
    int e = p[node].first;
    int f = p[node].second;
 
    int ans = query(1, 1, s.size(), e, f);
 
    // print the sum of subtree
    cout << "Subtree sum of node " << node << " is : " << (ans / 2) << endl;
 
    // query of type 2 return update
    // the subtree at node 6.
    int val = 10;
    node = 6;
 
    e = p[node].first;
    f = p[node].second;
    update(1, 1, s.size(), e, val);
    update(1, 1, s.size(), f, val);
 
    // query of type 1 return the
    // sum of subtree at node 2.
    node = 2;
 
    e = p[node].first;
    f = p[node].second;
 
    ans = query(1, 1, s.size(), e, f);
 
    // print the sum of subtree
    cout << "Subtree sum of node " << node << " is : " << (ans / 2) << endl;
 
    return 0;
}


Java
// Java program for implementation of
// Euler Tour | Subtree Sum.
import java.util.ArrayList;
 
class Graph{

// Mutable holder for a node's (entry time, exit time).
static class Pair
{
    int first, second;

    public Pair(int first, int second)
    {
        this.first = first;
        this.second = second;
    }
}

// v[x] lists the children of node x. The element type must be spelled out:
// with raw ArrayLists, get() returns Object, which cannot be used as an
// array index or passed to dfs(int), so the program would not compile.
@SuppressWarnings("unchecked")
static ArrayList<Integer>[] v = new ArrayList[1001];

// Euler-tour sequence: every node appears exactly twice (entry and exit).
static ArrayList<Integer> s = new ArrayList<>();

// Segment tree over s. Up to 1001 nodes give a tour of up to 2 * 1001
// entries, and a recursive segment tree needs up to 4x as many slots as
// leaves (1001 slots only cover about 250 entries).
static int[] seg = new int[8 * 1001];

// Value/Weight of each node of tree,
// value of 0th(no such node) node is 0.
static int ar[] = { 0, 1, 2, 3, 4, 5, 6 };

static int vertices = 6;
static int edges = 5;

// Recursively builds the segment tree over s. 'pos' is the index of the
// current node in seg[], covering tour positions [low, high]; each leaf
// holds ar[s.get(i)], the weight of the node visited at position i.
static void segment(int low, int high, int pos)
{
    if (high == low)
    {
        seg[pos] = ar[s.get(low)];
    }
    else
    {
        int mid = (low + high) / 2;
        segment(low, mid, 2 * pos);
        segment(mid + 1, high, 2 * pos + 1);
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}

// Returns the sum of the leaves at positions [l, r]. 'node' is the index of
// the current node in seg[], covering positions [start, end].
static int query(int node, int start,
                 int end, int l, int r)
{
    if (r < start || end < l)
    {
        return 0;
    }

    if (l <= start && end <= r)
    {
        return seg[node];
    }

    int mid = (start + end) / 2;
    int p1 = query(2 * node, start,
                       mid, l, r);
    int p2 = query(2 * node + 1, mid + 1,
                       end, l, r);

    return (p1 + p2);
}

/*
 * Point update: sets the leaf at position idx to val and recomputes the
 * sums of all its ancestors.
 * pos --> index of current node in segment tree seg[], covering [low, high].
 * idx --> position of the element to be updated.
 * val --> new value for that position.
 */
static void update(int pos, int low, int high,
                   int idx, int val)
{
    if (low == high)
    {
        seg[pos] = val;
    }
    else
    {
        int mid = (low + high) / 2;

        if (low <= idx && idx <= mid)
        {
            update(2 * pos, low, mid,
                       idx, val);
        }
        else
        {
            update(2 * pos + 1, mid + 1,
                       high, idx, val);
        }

        seg[pos] = seg[2 * pos] +
                   seg[2 * pos + 1];
    }
}

// Appends the Euler tour of the subtree rooted at 'root' to s: root is
// added on entry, and each child is added again after its own subtree has
// been visited. The caller must add the closing entry for 'root' itself.
static int dfs(int root)
{

    // Pushing each node in ArrayList s
    s.add(root);

    if (v[root].size() == 0)
        return root;

    for(int i = 0; i < v[root].size(); i++)
    {
        int temp = dfs(v[root].get(i));
        s.add(temp);
    }
    return root;
}

// Driver code
public static void main(String[] args)
{
    for(int i = 0; i < 1001; i++)
    {
        v[i] = new ArrayList<>();
    }

    // Edges between the nodes (parent -> child)
    v[1].add(2);
    v[1].add(3);
    v[2].add(6);
    v[2].add(5);
    v[3].add(4);

    // Calling dfs function; the root's closing entry is added here.
    int temp = dfs(1);
    s.add(temp);

    // Storing entry time and exit time of each node: the 1-based
    // positions of its two occurrences in s.
    ArrayList<Pair> p = new ArrayList<>();

    for(int i = 0; i <= vertices; i++)
        p.add(new Pair(0, 0));

    for(int i = 0; i < s.size(); i++)
    {
        if (p.get(s.get(i)).first == 0)
            p.get(s.get(i)).first = i + 1;
        else
            p.get(s.get(i)).second = i + 1;
    }

    // Build segment tree over 0-based positions [0, n - 1]. The calls below
    // use 1-based [1, n]; both split identically, so position k maps to
    // s.get(k - 1).
    segment(0, s.size() - 1, 1);

    // Query of type 1 return the
    // sum of subtree at node 1 (each node is counted twice, hence / 2).
    int node = 1;
    int e = p.get(node).first;
    int f = p.get(node).second;

    int ans = query(1, 1, s.size(), e, f);

    // Print the sum of subtree
    System.out.println("Subtree sum of node " +
                       node + " is : " + (ans / 2));

    // Query of type 2: update the weight of node 6 at both of its leaves.
    int val = 10;
    node = 6;

    e = p.get(node).first;
    f = p.get(node).second;
    update(1, 1, s.size(), e, val);
    update(1, 1, s.size(), f, val);

    // Query of type 1 return the
    // sum of subtree at node 2.
    node = 2;

    e = p.get(node).first;
    f = p.get(node).second;

    ans = query(1, 1, s.size(), e, f);

    // Print the sum of subtree
    System.out.println("Subtree sum of node " +
                       node + " is : " + (ans / 2));
}
}
 
// This code is contributed by sanjeev2552


Python3
# Python3 program for implementation of
# Euler Tour | Subtree Sum.

# v[x] lists the children of node x (adjacency list of the rooted tree).
v = [[] for _ in range(1001)]
# Euler-tour sequence: every node appears exactly twice (entry and exit).
s = []
# Segment tree over s. Up to 1001 nodes give a tour of up to 2 * 1001
# entries, and a recursive segment tree needs up to 4x as many slots as
# leaves; 1001 slots would only cover about 250 tour entries.
seg = [0] * (8 * 1001)

# Value/Weight of each node of tree,
# value of 0th(no such node) node is 0.
ar = [0, 1, 2, 3, 4, 5, 6]

vertices = 6
edges = 5
 
def segment(low, high, pos):
    """Build the segment-tree node ``pos`` covering tour positions low..high.

    Each leaf stores ``ar[s[i]]``, the weight of the node visited at tour
    position i; each internal node stores the sum of its two children,
    which live at indices ``2 * pos`` and ``2 * pos + 1`` of ``seg``.
    """
    if low == high:
        seg[pos] = ar[s[low]]
        return
    left, right = 2 * pos, 2 * pos + 1
    split = (low + high) // 2
    segment(low, split, left)
    segment(split + 1, high, right)
    seg[pos] = seg[left] + seg[right]
         
def query(node, start, end, l, r):
    """Return the sum of the leaves at tour positions l..r (inclusive).

    ``node`` is the index in ``seg`` of the segment-tree node covering
    positions start..end; the top-level call passes the root (1) and the
    full range.
    """
    # The node's range lies completely outside [l, r].
    if r < start or end < l:
        return 0
    # The node's range lies completely inside [l, r].
    if l <= start and end <= r:
        return seg[node]
    half = (start + end) // 2
    left_sum = query(2 * node, start, half, l, r)
    right_sum = query(2 * node + 1, half + 1, end, l, r)
    return left_sum + right_sum
 
def update(pos, low, high, idx, val):
    """Set the leaf at position ``idx`` to ``val`` and refresh its ancestors.

    ``pos`` is the index in ``seg`` of the node covering positions
    low..high; only the half containing ``idx`` is descended into.
    """
    if low == high:
        seg[pos] = val
        return
    mid = (low + high) // 2
    if low <= idx <= mid:
        update(2 * pos, low, mid, idx, val)
    else:
        update(2 * pos + 1, mid + 1, high, idx, val)
    seg[pos] = seg[2 * pos] + seg[2 * pos + 1]
     
def dfs(root):
    """Append the Euler tour of the subtree rooted at ``root`` to ``s``.

    ``root`` is appended on entry, and each child is appended again after
    its own subtree has been visited. The caller must append the closing
    entry for ``root`` itself. Returns ``root``.
    """
    s.append(root)
    for child in v[root]:
        s.append(dfs(child))
    return root
 
# Edges between the nodes, as (parent, child) pairs.
for parent, child in ((1, 2), (1, 3), (2, 6), (2, 5), (3, 4)):
    v[parent].append(child)

# Euler tour from the root; dfs() leaves the root's closing entry to us.
temp = dfs(1)
s.append(temp)

# p[x] = [entry time, exit time]: 1-based positions of x's two
# occurrences in s.
p = [[0, 0] for _ in range(vertices + 1)]
for i, x in enumerate(s):
    p[x][0 if p[x][0] == 0 else 1] = i + 1

# Build segment tree from the tour.
segment(0, len(s) - 1, 1)

# Type 1: sum of the subtree at node 1 (each node counted twice -> // 2).
node = 1
e, f = p[node]
ans = query(1, 1, len(s), e, f)
print("Subtree sum of node", node, "is :", ans // 2)

# Type 2: set node 6's weight to 10 at both of its leaves.
val = 10
node = 6
e, f = p[node]
for pos in (e, f):
    update(1, 1, len(s), pos, val)

# Type 1: sum of the subtree at node 2.
node = 2
e, f = p[node]
ans = query(1, 1, len(s), e, f)
print("Subtree sum of node",node,"is :",ans // 2)
 
# This code is contributed by SHUBHAMSINGH10


现在,使用向量s[]构建线段树。

下面是向量s[]对应的线段树表示(原文为图示,此处未包含)。

为了处理求和查询与更新查询,需要记录有根树中每个节点的进入时间和退出时间,即该节点在s[]中两次出现的位置(从1开始计数),并将其用作线段树上的下标范围。

s[]={1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1}

Node  Entry time  Exit time
   1        1           12
   2        2           7
   3        8           11
   4        9           10
   5        5           6
   6        3           4

查询类型1:
对于求和查询,在线段树上求区间[进入时间, 退出时间]的和。由于子树中的每个节点在s[]中都出现两次,得到的结果总是期望答案的两倍,因此需要把结果除以2。

查询类型2:
对于更新查询,需要同时更新线段树中对应该节点进入时间和退出时间的两个叶节点。

下面是上述方法的实现:

C++

// C++ program for implementation of
// Euler Tour | Subtree Sum.
#include <iostream>
#include <utility>
#include <vector>
using namespace std;

// v[x] lists the children of node x (adjacency list of the rooted tree).
vector<int> v[1001];
// Euler-tour sequence: every node appears exactly twice (entry and exit).
vector<int> s;
// Segment tree over s[]. A tree with up to 1001 nodes yields a tour of up
// to 2 * 1001 entries, and a recursive segment tree needs up to 4x as many
// slots as leaves, hence 8 * 1001 (1001 slots only cover ~250 entries).
int seg[8 * 1001] = { 0 };

// Value/Weight of each node of tree,
// value of 0th(no such node) node is 0.
int ar[] = { 0, 1, 2, 3, 4, 5, 6 };

int vertices = 6;
int edges = 5;
 
// A recursive function that constructs
// Segment Tree for array ar[] = { }.
// 'pos' is index of current node
// in segment tree seg[].
int segment(int low, int high, int pos)
{
    if (high == low) {
        seg[pos] = ar[s[low]];
    }
    else {
        int mid = (low + high) / 2;
        segment(low, mid, 2 * pos);
        segment(mid + 1, high, 2 * pos + 1);
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}
 
/* Return sum of elements in range
   from index l to r . It uses the
   seg[] array created using segment()
   function. 'pos' is index of current
   node in segment tree seg[].
*/
int query(int node, int start,
          int end, int l, int r)
{
    if (r < start || end < l) {
        return 0;
    }
 
    if (l <= start && end <= r) {
        return seg[node];
    }
 
    int mid = (start + end) / 2;
    int p1 = query(2 * node, start,
                   mid, l, r);
    int p2 = query(2 * node + 1, mid + 1,
                   end, l, r);
 
    return (p1 + p2);
}
 
/* A recursive function to update the
   nodes which have the given index in
   their range. The following are
   parameters pos --> index of current
   node in segment tree seg[]. idx -->
   index of the element to be updated.
   This index is in input array.
   val --> Value to be change at node idx
*/
int update(int pos, int low, int high,
           int idx, int val)
{
    if (low == high) {
        seg[pos] = val;
    }
    else {
        int mid = (low + high) / 2;
 
        if (low <= idx && idx <= mid) {
            update(2 * pos, low, mid,
                   idx, val);
        }
        else {
            update(2 * pos + 1, mid + 1,
                   high, idx, val);
        }
 
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}
 
/* A recursive function to form array
    ar[] from a directed tree .
*/
int dfs(int root)
{
    // pushing each node in vector s
    s.push_back(root);
    if (v[root].size() == 0)
        return root;
 
    for (int i = 0; i < v[root].size(); i++) {
        int temp = dfs(v[root][i]);
        s.push_back(temp);
    }
    return root;
}
 
// Driver program to test above functions
int main()
{
    // Edges between the nodes
    v[1].push_back(2);
    v[1].push_back(3);
    v[2].push_back(6);
    v[2].push_back(5);
    v[3].push_back(4);
 
    // Calling dfs function.
    int temp = dfs(1);
    s.push_back(temp);
 
    // Storing entry time and exit
    // time of each node
    vector > p;
 
    for (int i = 0; i <= vertices; i++)
        p.push_back(make_pair(0, 0));
 
    for (int i = 0; i < s.size(); i++) {
        if (p[s[i]].first == 0)
            p[s[i]].first = i + 1;
        else
            p[s[i]].second = i + 1;
    }
 
    // Build segment tree from array ar[].
    segment(0, s.size() - 1, 1);
 
    // query of type 1 return the
    // sum of subtree at node 1.
    int node = 1;
    int e = p[node].first;
    int f = p[node].second;
 
    int ans = query(1, 1, s.size(), e, f);
 
    // print the sum of subtree
    cout << "Subtree sum of node " << node << " is : " << (ans / 2) << endl;
 
    // query of type 2 return update
    // the subtree at node 6.
    int val = 10;
    node = 6;
 
    e = p[node].first;
    f = p[node].second;
    update(1, 1, s.size(), e, val);
    update(1, 1, s.size(), f, val);
 
    // query of type 1 return the
    // sum of subtree at node 2.
    node = 2;
 
    e = p[node].first;
    f = p[node].second;
 
    ans = query(1, 1, s.size(), e, f);
 
    // print the sum of subtree
    cout << "Subtree sum of node " << node << " is : " << (ans / 2) << endl;
 
    return 0;
}

Java

// Java program for implementation of
// Euler Tour | Subtree Sum.
import java.util.ArrayList;
 
class Graph{

// Mutable holder for a node's (entry time, exit time).
static class Pair
{
    int first, second;

    public Pair(int first, int second)
    {
        this.first = first;
        this.second = second;
    }
}

// v[x] lists the children of node x. The element type must be spelled out:
// with raw ArrayLists, get() returns Object, which cannot be used as an
// array index or passed to dfs(int), so the program would not compile.
@SuppressWarnings("unchecked")
static ArrayList<Integer>[] v = new ArrayList[1001];

// Euler-tour sequence: every node appears exactly twice (entry and exit).
static ArrayList<Integer> s = new ArrayList<>();

// Segment tree over s. Up to 1001 nodes give a tour of up to 2 * 1001
// entries, and a recursive segment tree needs up to 4x as many slots as
// leaves (1001 slots only cover about 250 entries).
static int[] seg = new int[8 * 1001];

// Value/Weight of each node of tree,
// value of 0th(no such node) node is 0.
static int ar[] = { 0, 1, 2, 3, 4, 5, 6 };

static int vertices = 6;
static int edges = 5;

// Recursively builds the segment tree over s. 'pos' is the index of the
// current node in seg[], covering tour positions [low, high]; each leaf
// holds ar[s.get(i)], the weight of the node visited at position i.
static void segment(int low, int high, int pos)
{
    if (high == low)
    {
        seg[pos] = ar[s.get(low)];
    }
    else
    {
        int mid = (low + high) / 2;
        segment(low, mid, 2 * pos);
        segment(mid + 1, high, 2 * pos + 1);
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}

// Returns the sum of the leaves at positions [l, r]. 'node' is the index of
// the current node in seg[], covering positions [start, end].
static int query(int node, int start,
                 int end, int l, int r)
{
    if (r < start || end < l)
    {
        return 0;
    }

    if (l <= start && end <= r)
    {
        return seg[node];
    }

    int mid = (start + end) / 2;
    int p1 = query(2 * node, start,
                       mid, l, r);
    int p2 = query(2 * node + 1, mid + 1,
                       end, l, r);

    return (p1 + p2);
}

/*
 * Point update: sets the leaf at position idx to val and recomputes the
 * sums of all its ancestors.
 * pos --> index of current node in segment tree seg[], covering [low, high].
 * idx --> position of the element to be updated.
 * val --> new value for that position.
 */
static void update(int pos, int low, int high,
                   int idx, int val)
{
    if (low == high)
    {
        seg[pos] = val;
    }
    else
    {
        int mid = (low + high) / 2;

        if (low <= idx && idx <= mid)
        {
            update(2 * pos, low, mid,
                       idx, val);
        }
        else
        {
            update(2 * pos + 1, mid + 1,
                       high, idx, val);
        }

        seg[pos] = seg[2 * pos] +
                   seg[2 * pos + 1];
    }
}

// Appends the Euler tour of the subtree rooted at 'root' to s: root is
// added on entry, and each child is added again after its own subtree has
// been visited. The caller must add the closing entry for 'root' itself.
static int dfs(int root)
{

    // Pushing each node in ArrayList s
    s.add(root);

    if (v[root].size() == 0)
        return root;

    for(int i = 0; i < v[root].size(); i++)
    {
        int temp = dfs(v[root].get(i));
        s.add(temp);
    }
    return root;
}

// Driver code
public static void main(String[] args)
{
    for(int i = 0; i < 1001; i++)
    {
        v[i] = new ArrayList<>();
    }

    // Edges between the nodes (parent -> child)
    v[1].add(2);
    v[1].add(3);
    v[2].add(6);
    v[2].add(5);
    v[3].add(4);

    // Calling dfs function; the root's closing entry is added here.
    int temp = dfs(1);
    s.add(temp);

    // Storing entry time and exit time of each node: the 1-based
    // positions of its two occurrences in s.
    ArrayList<Pair> p = new ArrayList<>();

    for(int i = 0; i <= vertices; i++)
        p.add(new Pair(0, 0));

    for(int i = 0; i < s.size(); i++)
    {
        if (p.get(s.get(i)).first == 0)
            p.get(s.get(i)).first = i + 1;
        else
            p.get(s.get(i)).second = i + 1;
    }

    // Build segment tree over 0-based positions [0, n - 1]. The calls below
    // use 1-based [1, n]; both split identically, so position k maps to
    // s.get(k - 1).
    segment(0, s.size() - 1, 1);

    // Query of type 1 return the
    // sum of subtree at node 1 (each node is counted twice, hence / 2).
    int node = 1;
    int e = p.get(node).first;
    int f = p.get(node).second;

    int ans = query(1, 1, s.size(), e, f);

    // Print the sum of subtree
    System.out.println("Subtree sum of node " +
                       node + " is : " + (ans / 2));

    // Query of type 2: update the weight of node 6 at both of its leaves.
    int val = 10;
    node = 6;

    e = p.get(node).first;
    f = p.get(node).second;
    update(1, 1, s.size(), e, val);
    update(1, 1, s.size(), f, val);

    // Query of type 1 return the
    // sum of subtree at node 2.
    node = 2;

    e = p.get(node).first;
    f = p.get(node).second;

    ans = query(1, 1, s.size(), e, f);

    // Print the sum of subtree
    System.out.println("Subtree sum of node " +
                       node + " is : " + (ans / 2));
}
}
 
// This code is contributed by sanjeev2552

Python3

# Python3 program for implementation of
# Euler Tour | Subtree Sum.

# v[x] lists the children of node x (adjacency list of the rooted tree).
v = [[] for _ in range(1001)]
# Euler-tour sequence: every node appears exactly twice (entry and exit).
s = []
# Segment tree over s. Up to 1001 nodes give a tour of up to 2 * 1001
# entries, and a recursive segment tree needs up to 4x as many slots as
# leaves; 1001 slots would only cover about 250 tour entries.
seg = [0] * (8 * 1001)

# Value/Weight of each node of tree,
# value of 0th(no such node) node is 0.
ar = [0, 1, 2, 3, 4, 5, 6]

vertices = 6
edges = 5
 
def segment(low, high, pos):
    """Build the segment-tree node ``pos`` covering tour positions low..high.

    Each leaf stores ``ar[s[i]]``, the weight of the node visited at tour
    position i; each internal node stores the sum of its two children,
    which live at indices ``2 * pos`` and ``2 * pos + 1`` of ``seg``.
    """
    if low == high:
        seg[pos] = ar[s[low]]
        return
    left, right = 2 * pos, 2 * pos + 1
    split = (low + high) // 2
    segment(low, split, left)
    segment(split + 1, high, right)
    seg[pos] = seg[left] + seg[right]
         
def query(node, start, end, l, r):
    """Return the sum of the leaves at tour positions l..r (inclusive).

    ``node`` is the index in ``seg`` of the segment-tree node covering
    positions start..end; the top-level call passes the root (1) and the
    full range.
    """
    # The node's range lies completely outside [l, r].
    if r < start or end < l:
        return 0
    # The node's range lies completely inside [l, r].
    if l <= start and end <= r:
        return seg[node]
    half = (start + end) // 2
    left_sum = query(2 * node, start, half, l, r)
    right_sum = query(2 * node + 1, half + 1, end, l, r)
    return left_sum + right_sum
 
def update(pos, low, high, idx, val):
    """Set the leaf at position ``idx`` to ``val`` and refresh its ancestors.

    ``pos`` is the index in ``seg`` of the node covering positions
    low..high; only the half containing ``idx`` is descended into.
    """
    if low == high:
        seg[pos] = val
        return
    mid = (low + high) // 2
    if low <= idx <= mid:
        update(2 * pos, low, mid, idx, val)
    else:
        update(2 * pos + 1, mid + 1, high, idx, val)
    seg[pos] = seg[2 * pos] + seg[2 * pos + 1]
     
def dfs(root):
    """Append the Euler tour of the subtree rooted at ``root`` to ``s``.

    ``root`` is appended on entry, and each child is appended again after
    its own subtree has been visited. The caller must append the closing
    entry for ``root`` itself. Returns ``root``.
    """
    s.append(root)
    for child in v[root]:
        s.append(dfs(child))
    return root
 
# Edges between the nodes, as (parent, child) pairs.
for parent, child in ((1, 2), (1, 3), (2, 6), (2, 5), (3, 4)):
    v[parent].append(child)

# Euler tour from the root; dfs() leaves the root's closing entry to us.
temp = dfs(1)
s.append(temp)

# p[x] = [entry time, exit time]: 1-based positions of x's two
# occurrences in s.
p = [[0, 0] for _ in range(vertices + 1)]
for i, x in enumerate(s):
    p[x][0 if p[x][0] == 0 else 1] = i + 1

# Build segment tree from the tour.
segment(0, len(s) - 1, 1)

# Type 1: sum of the subtree at node 1 (each node counted twice -> // 2).
node = 1
e, f = p[node]
ans = query(1, 1, len(s), e, f)
print("Subtree sum of node", node, "is :", ans // 2)

# Type 2: set node 6's weight to 10 at both of its leaves.
val = 10
node = 6
e, f = p[node]
for pos in (e, f):
    update(1, 1, len(s), pos, val)

# Type 1: sum of the subtree at node 2.
node = 2
e, f = p[node]
ans = query(1, 1, len(s), e, f)
print("Subtree sum of node",node,"is :",ans // 2)
 
# This code is contributed by SHUBHAMSINGH10
输出:
Subtree sum of node 1 is : 21
Subtree sum of node 2 is : 17

时间复杂度: O(q * log(n))