Segment Tree | (Product of given range modulo m)

Last modified: 2021-04-17 12:17:21    Author: Mango

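Let us consider the following problem to understand segment trees.
We have an array arr[0 . . . n-1]. We should be able to:
1. Find the product of the elements from index l to r, modulo an integer m, where 0 <= l <= r <= n-1.
2. Change the value of a specified element of the array to a new value x, i.e., set arr[i] = x, where 0 <= i <= n-1.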

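A simple solution is to run a loop from l to r, computing the product of the elements in the given range and reducing it modulo m. To update a value, simply set arr[i] = x. The first operation takes O(n) time and the second takes O(1) time.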
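A minimal Python sketch of this naive approach, for illustration only; the function names `range_product_naive` and `update_naive` are hypothetical and not part of the programs below. Reducing modulo m after every multiplication keeps intermediate values small.

```python
def range_product_naive(arr, l, r, m):
    """Product of arr[l..r] (inclusive) modulo m, in O(r - l + 1) time."""
    result = 1 % m            # 1 % m handles the edge case m == 1
    for k in range(l, r + 1):
        result = (result * arr[k]) % m
    return result


def update_naive(arr, i, x):
    """Point update in O(1): the naive scheme stores nothing else."""
    arr[i] = x


arr = [1, 2, 3, 4, 5, 6]
print(range_product_naive(arr, 1, 3, 1_000_000_000))  # 24
update_naive(arr, 1, 10)
print(range_product_naive(arr, 1, 3, 1_000_000_000))  # 120
```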

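Another solution is to create two arrays: the first stores the products (modulo m) of the elements from the beginning of the array up to each index, and the second stores the products (modulo m) from each index to the end of the array. The product of a given range can then be computed in O(1) time, but an update now takes O(n) time, because the stored products must be recomputed.
The product PDT of the given range l to r can be computed as follows:
P: product of all elements of the array, modulo m.
A: product of all elements from index 0 to l-1, modulo m.
B: product of all elements from index r+1 to n-1, modulo m.
PDT = (P * modInverse(A) * modInverse(B)) mod m
This requires A and B to have inverses modulo m, which holds, for example, when m is prime and no element of the array is divisible by m.
This method works well when the number of queries is large and there are very few updates.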
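The prefix/suffix method above can be sketched in Python under stated assumptions: m is prime (here 1_000_000_007) and no element is divisible by m, so that A and B are invertible. `pow(x, -1, m)` computes a modular inverse and requires Python 3.8 or later. The helper names are hypothetical.

```python
def build_prefix_suffix(arr, m):
    """prefix[i] = arr[0] * ... * arr[i-1] (mod m), with prefix[0] = 1;
    suffix[i] = arr[i] * ... * arr[n-1] (mod m), with suffix[n] = 1."""
    n = len(arr)
    prefix = [1] * (n + 1)
    suffix = [1] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] * arr[i] % m
    for i in range(n - 1, -1, -1):
        suffix[i] = suffix[i + 1] * arr[i] % m
    return prefix, suffix


def range_product_inverse(prefix, suffix, l, r, m):
    """PDT = P * modInverse(A) * modInverse(B) (mod m)."""
    P = prefix[-1]        # product of the whole array
    A = prefix[l]         # product of arr[0..l-1]
    B = suffix[r + 1]     # product of arr[r+1..n-1]
    # pow(x, -1, m) raises ValueError when x has no inverse modulo m
    return P * pow(A, -1, m) % m * pow(B, -1, m) % m


M = 1_000_000_007                      # assumed prime modulus
prefix, suffix = build_prefix_suffix([1, 2, 3, 4, 5, 6], M)
print(range_product_inverse(prefix, suffix, 1, 3, M))  # 24
```

Each `pow(x, -1, m)` call costs O(log m); precomputing the inverses of the prefix and suffix products would reduce a query to a constant number of lookups and multiplications. Any update still forces both arrays to be rebuilt in O(n).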

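Segment tree solution:
If the numbers of queries and updates are comparable, we would like both operations to run in O(log n) time. A segment tree achieves this: both the range-product query and the point update take O(log n) time.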

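Representation of segment trees
1. Leaf nodes are the elements of the input array.
2. Each internal node represents some merging of the leaf nodes. The merging may differ from problem to problem; for this problem, the merge is the product (modulo m) of the leaves under the node.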
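An array representation of the tree is used to represent segment trees. For each node at index i, the left child is at index 2*i + 1, the right child at index 2*i + 2, and the parent at index floor((i - 1) / 2).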
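The index arithmetic for this array layout, together with the number of slots that the constructST() functions below allocate (2 * 2^ceil(log2 n) - 1), can be sketched as follows. The helper names are hypothetical; `(n - 1).bit_length()` is used here as an integer-only way to get ceil(log2 n) for n >= 1, avoiding floating-point logarithms.

```python
def left_child(i):
    return 2 * i + 1


def right_child(i):
    return 2 * i + 2


def parent(i):
    # Defined for i >= 1; the root (index 0) has no parent.
    return (i - 1) // 2


def tree_size(n):
    """Array slots for a segment tree over n >= 1 leaves:
    2 * 2**ceil(log2(n)) - 1, the node count of a full binary
    tree of height ceil(log2(n))."""
    height = (n - 1).bit_length()      # equals ceil(log2(n)) for n >= 1
    return 2 * (1 << height) - 1


print(tree_size(6))                    # 15 slots for the 6-element example
print(left_child(0), right_child(0))   # 1 2
print(parent(1), parent(2))            # 0 0
```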

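Query for the product of a given range
Once the tree is constructed, how do we get the product of a range using it? The following is the algorithm to get the product of the elements. A node whose range lies completely outside [l, r] contributes 1, the identity for multiplication.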

int getPdt(node, l, r) 
{
   if range of node is within l and r
        return value in node
   else if range of node is completely outside l and r
        return 1
   else
    return (getPdt(node's left child, l, r)%mod * 
           getPdt(node's right child, l, r)%mod)%mod
}

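Updating a value
Like tree construction and query operations, an update can also be done recursively. We are given an index that needs to be updated. Starting from the root of the segment tree, we descend only into the nodes whose range contains that index, store the new value in the corresponding leaf, and on the way back up recompute each visited node as the product, modulo m, of its two children. If a node's range does not contain the given index, we make no change to that node. (Multiplying a node's product by the new value and dividing it by the previous value does not work here, because the stored products are already reduced modulo m and the previous value may be 0.)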
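A small numeric check of the multiply-then-divide pitfall, with purely illustrative numbers: for m = 10, a node covering the leaves 4 and 5 stores 20 mod 10 = 0, so the information the division would need is already gone.

```python
m = 10
left, right = 4, 5                 # two leaves under one internal node
node = (left * right) % m          # stored value: 20 % 10 = 0

# Change the right leaf from 5 to 3.
new_right = 3

# Multiply by the new value, divide by the old one, on the reduced value:
patched = (node * new_right) // right      # (0 * 3) // 5 = 0

# Recompute the node from its children instead:
recomputed = (left * new_right) % m        # 12 % 10 = 2

print(patched, recomputed)                 # 0 2
```

Recomputing from the children also avoids the division-by-zero error that the divide-by-old approach hits whenever the old value is 0.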

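Implementation:
The following is the implementation of the segment tree. The program constructs a segment tree for any given array and also implements the query and update operations.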

C++
// C++ program to show segment tree operations like
// construction, query and update
#include <iostream>   // cout
#include <cmath>      // ceil, log2, pow
using namespace std;
int mod = 1000000000;
  
// A utility function to get the middle index from
// corner indexes.
int getMid(int s, int e) {  return s + (e -s)/2;  }
  
/*  A recursive function to get the Pdt of values
    in given range of the array. The following are
    parameters for this function.
  
    st    --> Pointer to segment tree
    si    --> Index of current node in the segment tree.
              Initially 0 is passed as root is always
              at index 0
    ss & se  --> Starting and ending indexes of the
                 segment represented by current node,
                 i.e., st[si]
    qs & qe  --> Starting and ending indexes of query
                 range */
int getPdtUtil(int *st, int ss, int se, int qs, int qe,
                                                int si)
{
    // If segment of this node is a part of given
    // range, then return the Pdt of the segment
    if (qs <= ss && qe >= se)
        return st[si];
  
    // If segment of this node is outside the given range
    if (se < qs || ss > qe)
        return 1;
  
    // If a part of this segment overlaps with the
    // given range
    int mid = getMid(ss, se);
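    long long leftPdt = getPdtUtil(st, ss, mid, qs, qe, 2*si+1);
    long long rightPdt = getPdtUtil(st, mid+1, se, qs, qe, 2*si+2);

    // Multiply in 64 bits: the product of two values below
    // mod can exceed the range of int
    return (int)((leftPdt * rightPdt) % mod);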
}
  
/* A recursive function to update the nodes which have
   the given index in their range. The following are
   parameters
    st, si, ss and se are same as getPdtUtil()
    i       --> index of the element to be updated.
                This index is in input array.
    new_val --> new value of arr[i] */
void updateValueUtil(int *st, int ss, int se, int i,
                        int new_val, int si)
{
    // Base Case: If the input index lies outside
    // the range of this segment
    if (i < ss || i > se)
        return;

    // Leaf node for index i: store the new value
    if (ss == se)
    {
        st[si] = new_val % mod;
        return;
    }

    // Update the child that contains index i, then recompute
    // this node from its two children (st[] holds products
    // already reduced modulo mod, so dividing by the old
    // value would give wrong results)
    int mid = getMid(ss, se);
    updateValueUtil(st, ss, mid, i, new_val, 2*si + 1);
    updateValueUtil(st, mid+1, se, i, new_val, 2*si + 2);
    st[si] = (int)((1LL * st[2*si + 1] * st[2*si + 2]) % mod);
}
  
// The function to update a value in input array
// and segment tree. It uses updateValueUtil() to
// update the value in segment tree
void updateValue(int arr[], int *st, int n, int i,
                                      int new_val)
{
    // Check for erroneous input index
    if (i < 0 || i > n-1)
    {
        cout << "Invalid Input" << endl;
        return;
    }

    // Update the value in array
    arr[i] = new_val;
  
    // Update the values of nodes in segment tree
    updateValueUtil(st, 0, n-1, i, new_val, 0);
}
  
// Return Pdt of elements in range from index qs
// (query start)to qe (query end).  It mainly
// uses getPdtUtil()
int getPdt(int *st, int n, int qs, int qe)
{
    // Check for erroneous input values
    if (qs < 0 || qe > n-1 || qs > qe)
    {
        cout<<"Invalid Input";
        return -1;
    }
  
    return getPdtUtil(st, 0, n-1, qs, qe, 0);
}
  
// A recursive function that constructs Segment Tree
// for array[ss..se]. si is index of current node
// in segment tree st
int constructSTUtil(int arr[], int ss, int se,
                              int *st, int si)
{
    // If there is one element in array, store it
    // in current node of segment tree and return
    if (ss == se)
    {
        st[si] = arr[ss] % mod;
        return st[si];
    }
  
    // If there are more than one elements, then
    // recur for left and right subtrees and store
    // the Pdt of values in this node
    int mid = getMid(ss, se);
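    long long leftPdt = constructSTUtil(arr, ss, mid, st, si*2+1);
    long long rightPdt = constructSTUtil(arr, mid+1, se, st, si*2+2);

    // Multiply in 64 bits: the product of two values below
    // mod can exceed the range of int
    st[si] = (int)((leftPdt * rightPdt) % mod);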
    return st[si];
}
  
/* Function to construct segment tree from given array.
   This function allocates memory for segment tree and
   calls constructSTUtil() to fill the allocated memory */
int *constructST(int arr[], int n)
{
    // Allocate memory for segment tree
  
    // Height of segment tree
    int x = (int)(ceil(log2(n)));
  
    // Maximum size of segment tree
    int max_size = 2*(int)pow(2, x) - 1;
  
    // Allocate memory
    int *st = new int[max_size];
  
    // Fill the allocated memory st
    constructSTUtil(arr, 0, n-1, st, 0);
  
    // Return the constructed segment tree
    return st;
}
  
// Driver program to test above functions
int main()
{
    int arr[] = {1, 2, 3, 4, 5, 6};
    int n = sizeof(arr)/sizeof(arr[0]);
 
    // Build segment tree from given array
    int *st = constructST(arr, n);
  
    // Print Product of values in array from index 1 to 3
    cout << "Product of values in given range = "
         << getPdt(st, n, 1, 3) << endl;
            
    // Update: set arr[1] = 10 and update corresponding
    // segment tree nodes
    updateValue(arr, st, n, 1, 10);
  
    // Find Product after the value is updated
    cout << "Updated Product of values in given range = "
         << getPdt(st, n, 1, 3) << endl;
    return 0;
}


Java
// Java program to show segment tree operations
// like construction, query and update
class GFG{
 
static final int mod = 1000000000;
 
// A utility function to get the middle
// index from corner indexes.
static int getMid(int s, int e)
{
    return s + (e - s) / 2;
}
 
/*
 * A recursive function to get the Pdt of values
 * in given range of the array.
 * The following are parameters for this function.
 *
 * st --> Pointer to segment tree
 * si --> Index of current node in the segment tree.
 *        Initially 0 is passed as root is always
 *        at index 0
 * ss & se --> Starting and ending indexes of the
 *             segment represented by current node,
 *             i.e., st[si]
 * qs & qe --> Starting and ending indexes of query range
 */
static int getPdtUtil(int[] st, int ss, int se,
                      int qs, int qe, int si)
{
     
    // If segment of this node is a part of given
    // range, then return the Pdt of the segment
    if (qs <= ss && qe >= se)
        return st[si];
 
    // If segment of this node is outside
    // the given range
    if (se < qs || ss > qe)
        return 1;
 
    // If a part of this segment overlaps
    // with the given range
    int mid = getMid(ss, se);
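    long leftPdt = getPdtUtil(st, ss, mid, qs, qe, 2 * si + 1);
    long rightPdt = getPdtUtil(st, mid + 1, se, qs, qe, 2 * si + 2);
 
    // Multiply as long: the product of two values below
    // mod can exceed the range of int
    return (int)(leftPdt * rightPdt % mod);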
}
 
/*
 * A recursive function to update the nodes which have
 * the given index in their range. The following are
 * parameters
 * st, si, ss and se are same as getPdtUtil()
 * i --> index of the element to be updated.
 *        This index is in input array.
 * new_val --> new value of arr[i]
 */
static void updateValueUtil(int[] st, int ss, int se,
                            int i, int new_val, int si)
{
     
    // Base Case: If the input index lies outside
    // the range of this segment
    if (i < ss || i > se)
        return;
 
    // Leaf node for index i: store the new value
    if (ss == se)
    {
        st[si] = new_val % mod;
        return;
    }
 
    // Update the child that contains index i, then recompute
    // this node from its two children (st[] holds products
    // already reduced modulo mod, so dividing by the old
    // value would give wrong results)
    int mid = getMid(ss, se);
    updateValueUtil(st, ss, mid, i, new_val, 2 * si + 1);
    updateValueUtil(st, mid + 1, se, i, new_val, 2 * si + 2);
    st[si] = (int)((long)st[2 * si + 1] * st[2 * si + 2] % mod);
}
 
// The function to update a value in input array
// and segment tree. It uses updateValueUtil() to
// update the value in segment tree
static void updateValue(int arr[], int[] st, int n,
                        int i, int new_val)
{
     
    // Check for erroneous input index
    if (i < 0 || i > n - 1)
    {
        System.out.println("Invalid Input");
        return;
    }
 
    // Update the value in array
    arr[i] = new_val;
 
    // Update the values of nodes in segment tree
    updateValueUtil(st, 0, n - 1, i, new_val, 0);
}
 
// Return Pdt of elements in range from index qs
// (query start)to qe (query end). It mainly
// uses getPdtUtil()
static int getPdt(int[] st, int n, int qs, int qe)
{
     
    // Check for erroneous input values
    if (qs < 0 || qe > n - 1 || qs > qe)
    {
        System.out.println("Invalid Input");
        return -1;
    }
 
    return getPdtUtil(st, 0, n - 1, qs, qe, 0);
}
 
// A recursive function that constructs Segment Tree
// for array[ss..se]. si is index of current node
// in segment tree st
static int constructSTUtil(int arr[], int ss, int se,
                           int[] st, int si)
{
     
    // If there is one element in array, store it
    // in current node of segment tree and return
    if (ss == se)
    {
        st[si] = arr[ss] % mod;
        return st[si];
    }
 
    // If there are more than one elements, then
    // recur for left and right subtrees and store
    // the Pdt of values in this node
    int mid = getMid(ss, se);
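    long leftPdt = constructSTUtil(arr, ss, mid, st, si * 2 + 1);
    long rightPdt = constructSTUtil(arr, mid + 1, se, st, si * 2 + 2);
 
    // Multiply as long: the product of two values below
    // mod can exceed the range of int
    st[si] = (int)(leftPdt * rightPdt % mod);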
    return st[si];
}
 
/*
 * Function to construct segment tree from
 * given array. This function allocates memory
 * for segment tree and calls constructSTUtil()
 * to fill the allocated memory
 */
static int[] constructST(int arr[], int n)
{
     
    // Allocate memory for segment tree
 
    // Height of segment tree
    int x = (int)(Math.ceil(Math.log(n) /
                            Math.log(2)));
 
    // Maximum size of segment tree
    int max_size = 2 * (int)Math.pow(2, x) - 1;
 
    // Allocate memory
    int[] st = new int[max_size];
 
    // Fill the allocated memory st
    constructSTUtil(arr, 0, n - 1, st, 0);
 
    // Return the constructed segment tree
    return st;
}
 
// Driver code
public static void main(String[] args)
{
    int arr[] = { 1, 2, 3, 4, 5, 6 };
    int n = arr.length;
 
    // Build segment tree from given array
    int[] st = constructST(arr, n);
 
    // Print Product of values in array from
    // index 1 to 3
    System.out.printf("Product of values in " +
                      "given range = %d\n",
                      getPdt(st, n, 1, 3));
 
    // Update: set arr[1] = 10 and update
    // corresponding segment tree nodes
    updateValue(arr, st, n, 1, 10);
 
    // Find Product after the value is updated
    System.out.printf("Updated Product of values " +
                      "in given range = %d\n",
                      getPdt(st, n, 1, 3));
}
}
 
// This code is contributed by sanjeev2552


Python3
# Python3 program to show segment tree operations like
# construction, query and update
from math import ceil,log
mod = 1000000000
 
# A utility function to get the middle index from
# corner indexes.
def getMid(s, e):
    return s + (e -s)//2
 
"""A recursive function to get the Pdt of values
    in given range of the array. The following are
    parameters for this function.
 
    st --> Pointer to segment tree
    si --> Index of current node in the segment tree.
            Initially 0 is passed as root is always
            at index 0
    ss & se --> Starting and ending indexes of the
                segment represented by current node,
                i.e., st[si]
    qs & qe --> Starting and ending indexes of query
                range"""
def getPdtUtil(st, ss, se, qs, qe,si):
     
    # If segment of this node is a part of given
    # range, then return the Pdt of the segment
    if (qs <= ss and qe >= se):
        return st[si]
 
    # If segment of this node is outside the given range
    if (se < qs or ss > qe):
        return 1
 
    # If a part of this segment overlaps with the
    # given range
    mid = getMid(ss, se)
    return (getPdtUtil(st, ss, mid, qs, qe, 2*si+1)%mod*
        getPdtUtil(st, mid+1, se, qs, qe, 2*si+2)%mod)%mod

"""A recursive function to update the nodes which have
the given index in their range. The following are
parameters
    st, si, ss and se are same as getPdtUtil()
    i --> index of the element to be updated.
            This index is in input array.
    new_val --> new value of arr[i]"""
def updateValueUtil(st, ss, se, i, new_val, si):
     
    # Base Case: If the input index lies outside
    # the range of this segment
    if (i < ss or i > se):
        return
 
    # Leaf node for index i: store the new value
    if (ss == se):
        st[si] = new_val % mod
        return
 
    # Update the child that contains index i, then recompute this
    # node from its two children (st[] holds products already
    # reduced modulo mod, so dividing by the old value is wrong)
    mid = getMid(ss, se)
    updateValueUtil(st, ss, mid, i, new_val, 2*si + 1)
    updateValueUtil(st, mid+1, se, i, new_val, 2*si + 2)
    st[si] = (st[2*si + 1] * st[2*si + 2]) % mod
 
 
# The function to update a value in input array
# and segment tree. It uses updateValueUtil() to
# update the value in segment tree
def updateValue(arr, st, n, i, new_val):
     
    # Check for erroneous input index
    if (i < 0 or i > n-1):
        print("Invalid Input")
        return
 
    # Update the value in array
    arr[i] = new_val
 
    # Update the values of nodes in segment tree
    updateValueUtil(st, 0, n-1, i, new_val, 0)
 
# Return Pdt of elements in range from index qs
# (query start)to qe (query end). It mainly
# uses getPdtUtil()
def getPdt(st, n, qs, qe):
     
    # Check for erroneous input values
    if (qs < 0 or qe > n-1 or qs > qe):
        print("Invalid Input")
        return -1
 
    return getPdtUtil(st, 0, n-1, qs, qe, 0)
 
# A recursive function that constructs Segment Tree
# for array[ss..se]. si is index of current node
# in segment tree st
def constructSTUtil(arr, ss, se,st, si):
     
    # If there is one element in array, store it
    # in current node of segment tree and return
    if (ss == se):
        st[si] = arr[ss] % mod
        return st[si]
 
    # If there are more than one elements, then
    # recur for left and right subtrees and store
    # the Pdt of values in this node
    mid = getMid(ss, se)
    st[si] = (constructSTUtil(arr, ss, mid, st, si*2+1)%mod*
            constructSTUtil(arr, mid+1, se, st, si*2+2)%mod)%mod
    return st[si]
 
 
""" Function to construct segment tree from given array.
This function allocates memory for segment tree and
calls constructSTUtil() to fill the allocated memory
"""
def constructST(arr, n):
    # Allocate memory for segment tree
 
    # Height of segment tree
    x = (ceil(log(n,2)))
 
    # Maximum size of segment tree
    max_size = 2*pow(2, x) - 1
 
    # Allocate memory
    st = [0]*max_size
 
    # Fill the allocated memory st
    constructSTUtil(arr, 0, n-1, st, 0)
 
    # Return the constructed segment tree
    return st
 
# Driver program to test above functions
if __name__ == '__main__':
    arr=[1, 2, 3, 4, 5, 6]
    n = len(arr)
 
    # Build segment tree from given array
    st = constructST(arr, n)
 
    # Print product of values in array from index 1 to 3
    print("Product of values in given range =", getPdt(st, n, 1, 3))
 
    # Update: set arr[1] = 10 and update corresponding
    # segment tree nodes
    updateValue(arr, st, n, 1, 10)
 
    # Find Product after the value is updated
    print("Updated Product of values in given range =", getPdt(st, n, 1, 3))
 
# This code is contributed by mohit kumar 29


C#
// C# program to show segment tree operations
// like construction, query and update
using System;
class GFG
{
    static int mod = 1000000000;
   
    // A utility function to get the middle
    // index from corner indexes.
    public static int getMid(int s, int e)
    {
        return s + (e - s) / 2;
    }
      
/*
 * A recursive function to get the Pdt of values
 * in given range of the array.
 * The following are parameters for this function.
 *
 * st --> Pointer to segment tree
 * si --> Index of current node in the segment tree.
 *        Initially 0 is passed as root is always
 *        at index 0
 * ss & se --> Starting and ending indexes of the
 *             segment represented by current node,
 *             i.e., st[si]
 * qs & qe --> Starting and ending indexes of query range
 */
    public static int getPdtUtil(int[] st, int ss,
                                 int se,int qs,
                                 int qe, int si)
    {
       
        // If segment of this node is a part of given
        // range, then return the Pdt of the segment
        if(qs <= ss && qe >= se)
        {
            return st[si];
        }
       
        // If segment of this node is outside
        // the given range
        if(se < qs || ss > qe)
        {
            return 1;
        }
       
        // If a part of this segment overlaps
        // with the given range
        int mid=getMid(ss, se);
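        long leftPdt = getPdtUtil(st, ss, mid, qs, qe, 2 * si + 1);
        long rightPdt = getPdtUtil(st, mid + 1, se, qs, qe, 2 * si + 2);
       
        // Multiply as long: the product of two values below
        // mod can exceed the range of int
        return (int)(leftPdt * rightPdt % mod);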
    }
   
    /*
    * A recursive function to update the nodes which have
    * the given index in their range. The following are
    * parameters
    * st, si, ss and se are same as getPdtUtil()
    * i --> index of the element to be updated.
    *        This index is in input array.
    * new_val --> new value of arr[i]
    */
    public static void updateValueUtil(int[] st, int ss,
                                       int se, int i,
                                       int new_val, int si)
    {
       
        // Base Case: If the input index lies outside
        // the range of this segment
        if(i < ss || i > se)
        {
            return;
        }
       
        // Leaf node for index i: store the new value
        if (ss == se)
        {
            st[si] = new_val % mod;
            return;
        }
       
        // Update the child that contains index i, then recompute
        // this node from its two children (st[] holds products
        // already reduced modulo mod, so dividing by the old
        // value would give wrong results)
        int mid = getMid(ss, se);
        updateValueUtil(st, ss, mid, i, new_val, 2 * si + 1);
        updateValueUtil(st, mid + 1, se, i, new_val, 2 * si + 2);
        st[si] = (int)((long)st[2 * si + 1] * st[2 * si + 2] % mod);
    }
   
    // The function to update a value in input array
    // and segment tree. It uses updateValueUtil() to
    // update the value in segment tree
    public static void updateValue(int[] arr, int[] st,
                                   int n, int i, int new_val)
    {
       
        // Check for erroneous input index
        if(i < 0 || i > n - 1)
        {
            Console.WriteLine("Invalid Input");
            return;
        }
       
        // Update the value in array
        arr[i] = new_val;
       
        // Update the values of nodes in segment tree
        updateValueUtil(st, 0, n - 1, i, new_val, 0);
    }
   
    // Return Pdt of elements in range from index qs
    // (query start)to qe (query end). It mainly
    // uses getPdtUtil()
    public static int getPdt(int[] st, int n, int qs, int qe)
    {
       
        // Check for erroneous input values
        if(qs < 0 || qe > n - 1 || qs > qe)
        {
            Console.WriteLine("Invalid Input");
            return -1;
        }
        return getPdtUtil(st, 0, n - 1, qs, qe, 0);
    }
   
    // A recursive function that constructs Segment Tree
    // for array[ss..se]. si is index of current node
    // in segment tree st
    public static int constructSTUtil(int[] arr, int ss,
                                      int se,int[] st, int si)
    {
       
        // If there is one element in array, store it
        // in current node of segment tree and return
        if (ss == se)
        {
            st[si] = arr[ss] % mod;
            return st[si];
        }
       
        // If there are more than one elements, then
        // recur for left and right subtrees and store
        // the Pdt of values in this node
        int mid = getMid(ss, se);
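        long leftPdt = constructSTUtil(arr, ss, mid, st, si * 2 + 1);
        long rightPdt = constructSTUtil(arr, mid + 1, se, st, si * 2 + 2);
       
        // Multiply as long: the product of two values below
        // mod can exceed the range of int
        st[si] = (int)(leftPdt * rightPdt % mod);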
        return st[si];
    }
    /*
    * Function to construct segment tree from
    * given array. This function allocates memory
    * for segment tree and calls constructSTUtil()
    * to fill the allocated memory
    */
    public static int[] constructST(int[] arr, int n)
    {
       
        // Allocate memory for segment tree
  
        // Height of segment tree
        int x = (int)(Math.Ceiling(Math.Log(n) /Math.Log(2)));
       
        // Maximum size of segment tree
        int max_size = 2 * (int)Math.Pow(2, x) - 1;
       
        // Allocate memory
        int[] st = new int[max_size];
       
        // Fill the allocated memory st
        constructSTUtil(arr, 0, n - 1, st, 0);
       
        // Return the constructed segment tree
        return st;
    }
   
    // Driver code
    static public void Main ()
    {
        int[] arr = { 1, 2, 3, 4, 5, 6 };
        int n = arr.Length;
       
        // Build segment tree from given array
        int[] st = constructST(arr, n);
       
        // Print Product of values in array from
        // index 1 to 3
        Console.WriteLine("Product of values in " +
                          "given range = " + getPdt(st, n, 1, 3));
       
        // Update: set arr[1] = 10 and update
        // corresponding segment tree nodes
        updateValue(arr, st, n, 1, 10);
       
        // Find Product after the value is updated
        Console.WriteLine("Updated Product of values " +
                          "in given range = " + getPdt(st, n, 1, 3));
    }
}
 
// This code is contributed by avanitrachhadiya2155


Output:

Product of values in given range = 24
Updated Product of values in given range = 120