LeetCode in Kotlin

2458. Height of Binary Tree After Subtree Removal Queries

Hard

You are given the root of a binary tree with n nodes. Each node is assigned a unique value from 1 to n. You are also given an array queries of size m.

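You have to perform m independent queries on the tree where, in the i-th query, you do the following:

- Remove the subtree rooted at the node with the value queries[i] from the tree. It is guaranteed that queries[i] will not be equal to the value of the root.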

Return an array answer of size m where answer[i] is the height of the tree after performing the ith query.

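Note:

- The queries are independent, so the tree returns to its initial state after each query.
- The height of a tree is the number of edges in the longest simple path from the root to some node in the tree.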

Example 1:

Input: root = [1,3,4,2,null,6,5,null,null,null,null,null,7], queries = [4]

Output: [2]

Explanation: After removing the subtree rooted at the node with value 4, the height of the tree is 2 (the path 1 -> 3 -> 2).

Example 2:

Input: root = [5,8,9,2,1,3,7,4,6], queries = [3,2,4,8]

Output: [3,2,3,2]

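Explanation: We have the following queries:

- Removing the subtree rooted at the node with value 3. The height of the tree becomes 3 (the path 5 -> 8 -> 2 -> 4).
- Removing the subtree rooted at the node with value 2. The height of the tree becomes 2 (the path 5 -> 8 -> 1).
- Removing the subtree rooted at the node with value 4. The height of the tree becomes 3 (the path 5 -> 8 -> 2 -> 6).
- Removing the subtree rooted at the node with value 8. The height of the tree becomes 2 (the path 5 -> 9 -> 3).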

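Constraints:

- The number of nodes in the tree is n.
- 2 <= n <= 10^5
- 1 <= Node.val <= n
- All the values in the tree are unique.
- m == queries.length
- 1 <= m <= min(n, 10^4)
- 1 <= queries[i] <= n
- queries[i] != root.val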

Solution

import com_github_leetcode.TreeNode

class Solution {
    /**
     * Returns, for every query, the height of the tree after removing the subtree rooted at the
     * node whose value is `queries[i]`. Queries are independent: each one is answered against the
     * original, unmodified tree.
     *
     * Heights are counted in edges, so a lone node has height 0. For a removed node at depth `d`
     * whose own subtree has height `h`:
     * - if another node at depth `d` has a subtree at least as tall, the overall height is
     *   unchanged;
     * - if the removed node was the only node at depth `d`, the deepest remaining node is its
     *   parent, at depth `d - 1`;
     * - otherwise the tallest remaining node at depth `d` (subtree height `h'`) gives `d + h'`.
     * So it suffices to know, per depth, the largest and second-largest subtree heights.
     *
     * @param root root of the tree; node values are expected to be unique.
     * @param queries values of the nodes whose subtrees are removed. This array is not modified.
     * @return a new array whose `i`-th element is the tree height after the `i`-th removal.
     * @throws IllegalArgumentException if a query names a value that is not in the tree.
     */
    fun treeQueries(root: TreeNode?, queries: IntArray): IntArray {
        // Group nodes by depth with a breadth-first walk. This avoids recursion, which could
        // overflow the stack on a degenerate (list-shaped) tree with ~1e5 levels.
        val levels = ArrayList<List<TreeNode>>()
        var frontier = listOfNotNull(root)
        while (frontier.isNotEmpty()) {
            levels.add(frontier)
            frontier = frontier.flatMap { listOfNotNull(it.left, it.right) }
        }
        // Every depth 0..lastIndex holds at least one node, so this is the tree height in edges
        // (-1 for an empty tree).
        val treeHeight = levels.lastIndex

        val heightOf = HashMap<Int, Int>() // node value -> height of its subtree, in edges
        val depthOf = HashMap<Int, Int>() // node value -> depth of the node (root = 0)
        val tallest = IntArray(levels.size) // largest subtree height at each depth
        val runnerUp = IntArray(levels.size) { -1 } // second largest; -1 if the depth has one node

        // Edges a path gains by descending into `child` (0 when the child is absent).
        fun reachThrough(child: TreeNode?): Int =
            if (child == null) 0 else heightOf.getValue(child.`val`) + 1

        // Deepest level first, so every child's height is known before its parent needs it.
        for (depth in levels.indices.reversed()) {
            var first = -1
            var second = -1
            for (node in levels[depth]) {
                val height = maxOf(reachThrough(node.left), reachThrough(node.right))
                heightOf[node.`val`] = height
                depthOf[node.`val`] = depth
                // `>=` so that two equally tall nodes fill both slots (a tie keeps the height).
                if (height >= first) {
                    second = first
                    first = height
                } else if (height > second) {
                    second = height
                }
            }
            tallest[depth] = first
            runnerUp[depth] = second
        }

        return IntArray(queries.size) { i ->
            val value = queries[i]
            val height = requireNotNull(heightOf[value]) { "No node with value $value in the tree" }
            val depth = depthOf.getValue(value)
            when {
                // Another node at this depth is strictly taller, so the longest path survives.
                height != tallest[depth] -> treeHeight
                // The removed node was alone at its depth; its parent becomes the deepest node.
                runnerUp[depth] == -1 -> depth - 1
                // The best remaining path runs through the next-tallest node at this depth.
                else -> depth + runnerUp[depth]
            }
        }
    }
}