从根节点到树中每个其他节点的 MEX

MEX from root node to every other node in a tree

给定一个有 N 个节点的有根树。根节点是节点 1。每个第 i 个节点都有一些值 val[i] 与之关联。

对于每个节点 i (1<=i<=N),我们想知道从根节点到节点 i 的路径值的 MEX。

数组的MEX是数组中不存在的最小正整数,例如{1,2,4}的MEX是3

示例: 假设我们有一个有 4 个节点的树。节点的值为 [1,3,2,8],我们还有每个节点 i 的父节点(节点 1 除外,因为它是根节点)。本例中父数组定义为 [1,2,2]。表示节点2的父节点是节点1,节点3的父节点是节点2,节点4的父节点也是节点2。

Node 1 : MEX(1) = 2
Node 2 : MEX(1,3) = 2
Node 3 : MEX(1,3,2) = 4
Node 4 : MEX(1,3,8) = 2

因此答案是 [2,2,4,2]

在最坏的情况下,节点总数可达 10^6,每个节点的值可达 10^9。

尝试:

方法 1:我们知道 N 个元素的 MEX 总是在 1 到 N+1 之间。我试图将这种理解用于这个树问题,但在这种情况下,N 将随着叶节点的前进而不断动态变化。

方法 2:另一个想法是创建一个包含 N+1 个空值的数组,然后尝试从根节点开始填充它们。但后来我面临的挑战是跟踪此数组中的第一个未填充值。

这可以及时完成 O(n log n) 使用增强 BST。

假设您有一个支持以下操作的数据结构:

  1. insert(x),添加数字 x 的副本。
  2. remove(x),删除数字 x 的副本。
  3. mex(),returns 集合的 MEX。

有了类似的东西,您可以通过执行递归树遍历、在开始访问节点时插入项目并在离开节点时删除这些项目来轻松解决问题。这将对这些函数中的每一个进行 n 次调用,因此目标是最小化它们的成本。

我们可以使用增强 BST 来做到这一点。现在,假设原始树中的所有数字都是不同的;稍后我们会处理重复的情况。从您选择的 BST 开始,并通过让每个节点存储其左子树中的节点数来扩充它。这可以在不改变插入或删除的渐近成本的情况下完成(如果你以前没有见过这个,请查看 order statistic tree 数据结构)。然后,您可以按如下方式找到 MEX。从根开始,查看其值及其左子树中的节点数。将发生以下情况之一:

  1. 节点的值k恰好是左子树节点数的1。这意味着所有值 1、2、3、...、k 都在树中,因此 MEX 将是右子树中缺失的最小值。递归地找到右子树的 MEX。当你这样做时,请记住你已经看到了从 1 到 k 的值,方法是从你遇到的所有值中减去 k。
  2. 节点的k值至少比左子树的节点数多2。这意味着在节点的左子树加上根节点的某处存在间隙。递归求左子树的MEX。

离开树后,您可以查看最后一个正确的节点并向其添加一个以获得 MEX。 (如果你从来没有走对,MEX 就是 1)。

这是对每个节点执行 O(1) 工作的平衡树的自上而下传递,因此总共需要 O(log n) 工作。

唯一复杂的是,如果原始树(不是扩充 BST)中的值在路径上重复,会发生什么情况。但这很容易修复:只需向每个 BST 节点添加一个计数字段来跟踪它出现的次数,在发生插入时递增它,在发生删除时递减它。然后,只有在频率降为零的情况下,才将节点从BST中移除。

总的来说,对这样一棵树的每个操作都需要时间 O(log n),因此这为您的原始问题提供了 O(n log n) 时间算法。

public class TestClass {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter wr = new PrintWriter(System.out);
        int T = Integer.parseInt(br.readLine().trim());
        for(int t_i = 0; t_i < T; t_i++)
        {
            int N = Integer.parseInt(br.readLine().trim());
            String[] arr_val = br.readLine().split(" ");
            int[] val = new int[N];
            for(int i_val = 0; i_val < arr_val.length; i_val++)
            {
                val[i_val] = Integer.parseInt(arr_val[i_val]);
            }
            String[] arr_parent = br.readLine().split(" ");
            int[] parent = new int[N-1];
            for(int i_parent = 0; i_parent < arr_parent.length; i_parent++)
            {
                parent[i_parent] = Integer.parseInt(arr_parent[i_parent]);
            }

            int[] out_ = solve(N, val, parent);
            System.out.print(out_[0]);
            for(int i_out_ = 1; i_out_ < out_.length; i_out_++)
            {
                System.out.print(" " + out_[i_out_]);
            }
            
            System.out.println();
            
         }

         wr.close();
         br.close();
    }
    static int[] solve(int N, int[] val, int[] parent){
       // Write your code here
        int[] result = new int[val.length];
        ArrayList<ArrayList<Integer>> temp = new ArrayList<>();
        ArrayList<Integer> curr = new ArrayList<>();

        if(val[0]==1)
            curr.add(2);
        else{
            curr.add(1);
            curr.add(val[0]);
        }
        result[0]=curr.get(0);
        temp.add(new ArrayList<>(curr));
        for(int i=1;i<val.length;i++){

            int parentIndex = parent[i-1]-1;
            curr = new ArrayList<>(temp.get(parentIndex));
            int nodeValue = val[i];
            boolean enter = false;
            while(curr.size()>0 && nodeValue == curr.get(0)){
                curr.remove(0);
                nodeValue++;
                enter=true;
            }
            if(curr.isEmpty())
                curr.add(nodeValue);
            else if(!curr.isEmpty() && curr.contains(nodeValue) ==false && (enter|| curr.get(0)<nodeValue))
                curr.add(nodeValue);

            Collections.sort(curr);
            temp.add(new ArrayList<>(curr));
            result[i]=curr.get(0);
        }

        return result;
    
    }
}
public class PathMex {
    static void dfs(int node, int mexVal, int[] res, int[] values, ArrayList<ArrayList<Integer>> adj, HashMap<Integer, Integer> map) {
        if (!map.containsKey(values[node])) {
            map.put(values[node], 1);
        }
        else {
            map.put(values[node], map.get(values[node]) + 1);
        }
        while(map.containsKey(mexVal)) mexVal++;
        res[node] = mexVal;

        ArrayList<Integer> children = adj.get(node);
        for (Integer child : children) {
            dfs(child, mexVal, res, values, adj, map);
        }

        if (map.containsKey(values[node])) {
            if (map.get(values[node]) == 1) {
                map.remove(values[node]);
            }
            else {
                map.put(values[node], map.get(values[node]) - 1);
            }
        }
    }

    static int[] findPathMex(int nodes, int[] values, int[] parent) {
        ArrayList<ArrayList<Integer>> adj = new ArrayList<>(nodes);
        HashMap<Integer, Integer> map = new HashMap<>();

        int[] res = new int[nodes];

        for (int i = 0; i < nodes; i++) {
            adj.add(new ArrayList<Integer>());
        }
        for (int i = 0; i < nodes - 1; i++) {
            adj.get(parent[i] - 1).add(i + 1);
        }
        dfs(0, 1, res, values, adj, map);
        return res;
    }

    public static void main(String args[]) {
        Scanner sc = new Scanner(System.in);
        int nodes = sc.nextInt();
        int[] values = new int[nodes];
        int[] parent = new int[nodes - 1];
        for (int i = 0; i < nodes; i++) {
            values[i] = sc.nextInt();
        }
        for (int i = 0; i < nodes - 1; i++) {
            parent[i] = sc.nextInt();
        }
        int[] res = findPathMex(nodes, values, parent);
        for (int i = 0; i < nodes; i++) {
            System.out.print(res[i] + " ");
        }
    }
}