📜  查找每个给定 N 个区间右侧最近的非重叠区间的索引(1)

📅  最后修改于: 2023-12-03 15:26:38.586000             🧑  作者: Mango

查找每个给定 N 个区间右侧最近的非重叠区间的索引

在进行区间问题的解决时,有时需要查找每个给定的N个区间右侧最近的非重叠区间的索引。这个问题可以使用贪心算法和线段树结合的方式来解决。

贪心算法

对于每个区间,我们可以记录下它的右端点,然后按照右端点从小到大排序。接下来,我们对于每个区间,找到它右侧第一个不与它重叠的区间。这样,我们就可以获取每个区间右侧最近的非重叠区间的索引。

这个算法的时间复杂度为$O(nlogn)$,其中$n$为输入区间的数量,主要时间消耗在排序上。

线段树结合贪心算法

线段树是一种高效的区间查询算法,它可以在$logn$时间内找到给定区间的某种特定值。利用线段树可以进一步提高上述算法的效率。

首先,我们还是根据区间右端点的大小对所有区间进行排序。接着,我们使用线段树来维护每个区间右侧第一个不与它重叠的区间的索引。

我们使用一个数组$tree$来存储线段树的结构,其中$tree[i]$表示第$i$个节点所覆盖的区间。对于每个节点$i$,我们维护两个值$left$和$right$。$left$表示区间$[tree[i].left,tree[i].right]$中最近的右侧不与其重叠的区间的索引,$right$表示区间$[tree[i].left,tree[i].right]$中最右侧区间的索引。对于每个区间$j$,我们在线段树中找到对应的区间$i$,如果区间$[tree[i].left,tree[i].right]$与区间$j$重叠,则继续查找它右侧的区间。若找到区间$k$与区间$j$不重叠,则更新$left$值。

以上算法的时间复杂度为$O(nlog^2n)$。

代码实现
算法一
def find_non_overlapping_interval(n: int, intervals: List[Tuple[int, int]]) -> List[int]:
    right_end_points = sorted([interval[1] for interval in intervals])
    result = [-1] * n
    
    for i in range(n):
        for j in range(i + 1, n):
            if intervals[j][0] >= intervals[i][1]:
                result[i] = j
                break
    
    return result
算法二
class IntervalNode:
    def __init__(self, left: int, right: int) -> None:
        self.left = left
        self.right = right
        self.left_node = None
        self.right_node = None
        self.left_index = -1
        self.right_index = -1


def insert(node: IntervalNode, left: int, right: int, index: int) -> None:
    if node.left >= right or node.right <= left:
        return
    
    if left <= node.left and right >= node.right:
        node.left_index = index
    elif node.left_node or node.right_node:
        insert(node.left_node, left, right, index)
        insert(node.right_node, left, right, index)
        
        left_index_left = node.left_node.left_index
        left_index_right = (node.left_node.right_index if left_index_left != -1 else node.right_node.left_index)
        right_index_right = node.right_node.right_index
        right_index_left = (node.right_node.left_index if right_index_right != -1 else node.left_node.right_index)
        
        if left_index_left == -1:
            node.left_index = left_index_right
        elif right_index_right == -1:
            node.right_index = right_index_left
        else:
            if intervals[left_index_right][0] - intervals[left_index_left][1] <= intervals[right_index_left][0] - intervals[right_index_right][1]:
                node.left_index = left_index_right
            else:
                node.right_index = right_index_left
    else:
        mid = (node.left + node.right) // 2
        node.left_node = IntervalNode(node.left, mid)
        node.right_node = IntervalNode(mid, node.right)
        insert(node.left_node, left, right, index)
        insert(node.right_node, left, right, index)


def find(node: IntervalNode, left: int, right: int) -> int:
    result = -1
    
    if node.left_index != -1 and intervals[node.left_index][0] >= right:
        return node.left_index
    if node.right_index != -1 and intervals[node.right_index][0] >= right:
        return node.right_index
    
    if node.left_node:
        result_left = find(node.left_node, left, right)
        result_right = find(node.right_node, left, right)
        
        if result_left != -1 and result_right != -1:
            if intervals[result_left][0] - intervals[result][1] <= intervals[result_right][0] - intervals[result_right][1]:
                result = result_left
            else:
                result = result_right
        elif result_left != -1:
            result = result_left
        else:
            result = result_right
    
    return result


def find_non_overlapping_interval(n: int, intervals: List[Tuple[int, int]]) -> List[int]:
    right_end_points = sorted([interval[1] for interval in intervals])
    root = IntervalNode(0, right_end_points[-1] + 1)
    result = [-1] * n
    
    for i in range(n):
        insert(root, intervals[i][0], intervals[i][1], i)
    
    for i in range(n):
        result[i] = find(root, intervals[i][1], right_end_points[-1] + 1)
        
    return result