📜  门| GATE-CS-2016(套装2)|第 31 题(1)

📅  最后修改于: 2023-12-03 15:28:44.996000             🧑  作者: Mango

题目描述

有一颗二叉树,定义其完美权值平衡(Pefectly Weight Balanced, 简称PWB)当且仅当任意一个节点的权值为其所有子树内所有权值的平均数。现有一个长度为 $n$ 的序列 $a_1, a_2, \cdots, a_n$,每个数字代表一个二叉树的节点权值,判断是否存在一个PWB二叉树使得其所有节点都可以映射到 $a_1, a_2, \cdots, a_n$ 中的某个数字,并求出可能的方案数。

输入格式

第一行为整数 $T$,代表有 $T$ 组测试数据。每组测试数据的第一行为 $n$,表示二叉树元素个数。接下来一行有 $n$ 个整数,具体代表 $a_1, a_2, \cdots, a_n$。

输出格式

对于每个测试数据,首先输出“Case #”和当前测试数据组数,然后输出一行代表方案数的整数。

数据范围

$1 \leq T \leq 10^5$, $2 \leq n \leq 10^6$, $-10^9 \leq a_i \leq 10^9$

输入样例
2
2
1 2
4
1 2 2 3
输出样例
Case #1: 2
Case #2: 1
题解
思路

首先我们可以考虑题目所给的条件。显然,一个节点为其所有子树的平均数,可以被看作是其所有子树节点的和为当前节点的值的 $k$ 倍,即 $sum_x = k \times w_x$,其中 $sum_x$ 代表节点 $x$ 的所有后代节点之和,$w_x$ 代表节点 $x$ 的权值。

由此我们可以想到,如果一棵二叉树中,每个节点的所有子树节点的权值之和都相等,那么这棵二叉树就满足题目所述条件,并且可以称之为那 $k$ 的一种取值。

那么问题就归结为判断是否存在一个取值 $k$,让一棵完全二叉树中每个节点都能够得到权值,且其中任意一个节点的权值都是其所有子树节点权值的平均数。

为了方便处理,可以先将输入的数据(即所有的 $a_i$)从小到大排序。这样我们可以保证,对于任意的 $i$,所有小于 $a_i$ 的节点必定在节点 $i$ 的某一个子树中;所有大于 $a_i$ 的节点必定不在节点 $i$ 的任何子树中。

回到原问题,通过观察,我们可以发现此题问题的两大复杂之处:

  1. 我们无法引入某些具体的数据结构来完成求解。尤其是对于二叉树,它本身就非常抽象,让我们更难想象如何引入特定的数据结构帮助处理。

  2. 我们无法枚举 $k$ 的所有值,然后依次检查是否有符合条件的完全二叉树。因为显然枚举 $k$ 的复杂度是 $O(n)$ 的,而根据主定理,单次判断一个可能的解,其复杂度也为 $O(n)$。这样一来,我们就需要 $O(n ^ 2)$ 的时间复杂度来完成这道题,这显然是无法接受的。

那么可以思考一种更加直观的解法吗?

解法

接下来,我们考虑问题的一种更加直观的思路。如果我们已经计算出了一棵完全二叉树中某个节点的权值 $a$,有没有一种快速的方法,能够在已知 $a$ 的条件下,求出其余节点权值的结果?

显然,对于一个节点 $x$,根据公式 $a_x = \frac {sum_x} {n_x}$,我们可以得到:

$$ \frac {n_x \times a_x}{k} = \sum _{y\in son_x}a_y $$

其中 $son_x$ 表示节点 $x$ 的左右儿子,$n_x = |son_x|$ 表示节点 $x$ 所有子树的节点个数之和。

由此我们可以得到一个类似于DP的递推式:从叶子节点开始,逐次计算当前节点所代表的子树内所有节点之和,并借此计算出它的父节点的值。

但是,我们依然不知道 $k$。考虑如何实现对 $k$ 的求解。对于任意一个节点 $x$,根据公式 $sum_x = k \times w_x$,我们可以推导得到:

$$ \frac{\sum_{y \in subtree(x)} w_y}{k} = n $$

其中 $subtree(x)$ 表示节点 $x$ 所有后代节点的集合。

由此,我们可以发现一个好的性质:当我们固定 $k$ 时,这个二叉树的性质就会变成一个关于二叉树中节点权值的方程组。我们只需要通过解方程组的方法,找到符合条件的节点 $a_i$,就可以得到一个解了。

至于如何解方程组,这里可以采用任何一种线性方程求解法,本文采用较为简单的高斯消元。

注意一棵树不一定是完全二叉树,所以我们需要对树进行平衡化处理,变成一棵完全二叉树。这可以通过对节点编号(逐层从左往右,从上往下)进行重新排列,来实现。

代码
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int T = scanner.nextInt();
        for (int caseIndex = 1; caseIndex <= T; caseIndex++) {
            int n = scanner.nextInt();
            int[] a = new int[n];
            for (int i = 0; i < n; i++) {
                a[i] = scanner.nextInt();
            }
            Arrays.sort(a);
            for (int i = 0; i < n; i++) {
                a[i] -= a[0];
            }
            Map<Integer, Integer> id = new HashMap<>();
            Queue<Integer> queue = new LinkedList<>();
            queue.add(0); id.put(0, 0);
            int currentIndex = 1;
            while (!queue.isEmpty()) {
                int curNodeIndex = queue.poll();
                if (curNodeIndex * 2 + 1 < n && a[curNodeIndex * 2 + 1] != a[curNodeIndex]) {
                    id.put(curNodeIndex * 2 + 1, currentIndex++);
                    queue.add(curNodeIndex * 2 + 1);
                }
                if (curNodeIndex * 2 + 2 < n && a[curNodeIndex * 2 + 2] != a[curNodeIndex]) {
                    id.put(curNodeIndex * 2 + 2, currentIndex++);
                    queue.add(curNodeIndex * 2 + 2);
                }
            }
            double[] weight = new double[currentIndex];
            for (Map.Entry<Integer, Integer> entry : id.entrySet()) {
                weight[entry.getValue()] = (double) a[entry.getKey()] / n;
            }
            double[][] linearEquations = new double[currentIndex][currentIndex + 1];
            for (int i = 0; i < currentIndex; i++) {
                int nCount = getChildCount(a, i, n);
                for (int j = 0; j < currentIndex; j++) {
                    if (j == i) {
                        linearEquations[i][j] = nCount;
                    } else if (j == (i - 1) / 2) {
                        linearEquations[i][j] = -1 * nCount;
                    } else if (j == currentIndex) {
                        linearEquations[i][j] = weight[i] * nCount;
                    }
                }
            }
            gaussianElimination(linearEquations, currentIndex);
            System.out.println(String.format("Case #%d: %d", caseIndex, getResult(linearEquations, currentIndex)));
        }
    }

    private static int getChildCount(int[] a, int i, int n) {
        int ret = 1;
        int lChildIndex = 2 * i + 1;
        int rChildIndex = 2 * i + 2;
        if (lChildIndex < n && a[lChildIndex] != a[i]) {
            ret += getChildCount(a, lChildIndex, n);
        }
        if (rChildIndex < n && a[rChildIndex] != a[i]) {
            ret += getChildCount(a, rChildIndex, n);
        }
        return ret;
    }

    private static int getResult(double[][] linearEquations, int n) {
        double max = Integer.MIN_VALUE;
        int maxIndex = 0;
        for (int i = 0; i < n; i++) {
            if (linearEquations[i][n] / linearEquations[i][i] > max) {
                max = linearEquations[i][n] / linearEquations[i][i];
                maxIndex = i;
            }
        }
        return (int) (linearEquations[maxIndex][n] / linearEquations[maxIndex][maxIndex]);
    }

    private static void gaussianElimination(double[][] matrix, int n) {
        for (int col = 0, row = 0; col < n && row < n; col++, row++) {
            int maxRowIndex = row;
            for (int i = row; i < n; i++) {
                if (Math.abs(matrix[i][col]) > Math.abs(matrix[maxRowIndex][col])) {
                    maxRowIndex = i;
                }
            }
            if (Double.valueOf(0).compareTo(matrix[maxRowIndex][col]) == 0) {
                col--;
                continue;
            }
            if (maxRowIndex != row) {
                for (int i = col; i < n + 1; i++) {
                    double tmp = matrix[row][i];
                    matrix[row][i] = matrix[maxRowIndex][i];
                    matrix[maxRowIndex][i] = tmp;
                }
            }
            for (int i = col + 1; i < n + 1; i++) {
                matrix[row][i] /= matrix[row][col];
            }
            matrix[row][col] = 1;
            for (int i = 0; i < n; i++) {
                if (i == row) continue;
                for (int j = col + 1; j < n + 1; j++) {
                    matrix[i][j] -= matrix[row][j] * matrix[i][col];
                }
                matrix[i][col] = 0;
            }
        }
    }
}