LeetCode in Kotlin

3515. Shortest Path in a Weighted Tree

Hard

You are given an integer n and an undirected, weighted tree rooted at node 1 with n nodes numbered from 1 to n. This is represented by a 2D array edges of length n - 1, where edges[i] = [ui, vi, wi] indicates an undirected edge between nodes ui and vi with weight wi.

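You are also given a 2D integer array queries of length q, where each queries[i] is either:

- [1, u, v, w'] – update the weight of the edge between nodes u and v to w', or
- [2, x] – compute the shortest path distance from the root node 1 to node x.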

Return an integer array answer, where answer[i] is the shortest path distance from node 1 to x for the ith query of [2, x].

Example 1:

Input: n = 2, edges = [[1,2,7]], queries = [[2,2],[1,1,2,4],[2,2]]

Output: [7,4]

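Explanation:

- Query [2,2]: The shortest path from root node 1 to node 2 is 7.
- Query [1,1,2,4]: The weight of edge (1,2) changes from 7 to 4.
- Query [2,2]: The shortest path from root node 1 to node 2 is now 4.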

Example 2:

Input: n = 3, edges = [[1,2,2],[1,3,4]], queries = [[2,1],[2,3],[1,1,3,7],[2,2],[2,3]]

Output: [0,4,2,7]

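Explanation:

- Query [2,1]: The shortest path from root node 1 to itself is 0.
- Query [2,3]: The shortest path from root node 1 to node 3 is 4.
- Query [1,1,3,7]: The weight of edge (1,3) changes from 4 to 7.
- Query [2,2]: The shortest path from root node 1 to node 2 is 2.
- Query [2,3]: The shortest path from root node 1 to node 3 is now 7.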

Example 3:

Input: n = 4, edges = [[1,2,2],[2,3,1],[3,4,5]], queries = [[2,4],[2,3],[1,2,3,3],[2,2],[2,3]]

Output: [8,3,2,5]

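Explanation:

- Query [2,4]: The shortest path from root node 1 to node 4 is 1 → 2 → 3 → 4, with length 2 + 1 + 5 = 8.
- Query [2,3]: The shortest path from root node 1 to node 3 is 1 → 2 → 3, with length 2 + 1 = 3.
- Query [1,2,3,3]: The weight of edge (2,3) changes from 1 to 3.
- Query [2,2]: The shortest path from root node 1 to node 2 is 2.
- Query [2,3]: The shortest path from root node 1 to node 3 is now 2 + 3 = 5.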

Constraints:

Solution

/**
 * Shortest Path in a Weighted Tree.
 *
 * In a tree the only path from the root to `x` is the tree path, so its length is the sum of the
 * edge weights along it. Changing the weight of the edge between a node `p` and its child `c` by
 * `diff` shifts the root distance of exactly the nodes in `c`'s subtree by `diff`.
 *
 * A DFS entry/exit numbering makes every subtree a contiguous index range. Each weight update
 * therefore becomes a range add and each distance query a point read on a Fenwick tree, both
 * O(log n). The current distance is the initial DFS distance plus the accumulated shift.
 */
class Solution {
    /**
     * Answers the distance queries in [queries] while applying the weight updates in order.
     *
     * @param n number of nodes, labelled `1..n`; the tree is rooted at node 1.
     * @param edges `n - 1` triples `[u, v, w]`, each an undirected edge of weight `w`.
     * @param queries each either `[1, u, v, w]` (set the weight of edge `u`–`v` to `w`) or a
     *   distance query `[2, x]`. Any query whose first element is not 1 is treated as a distance
     *   query. An update naming a pair that is not a tree edge is ignored.
     * @return the answers to the distance queries, in the order they appear.
     */
    fun treeQueries(n: Int, edges: Array<IntArray>, queries: Array<IntArray>): IntArray {
        val tree = buildTree(n, edges)
        // Current weight of the edge from each non-root node to its parent.
        val weight = tree.parentEdgeWeight.copyOf()
        val shifts = Fenwick(n)
        val answers = IntArray(queries.count { it[0] != 1 })
        var written = 0
        for (query in queries) {
            if (query[0] == 1) {
                val u = query[1]
                val v = query[2]
                // The lower endpoint identifies both the edge and the subtree whose distances move.
                val child = if (tree.parent[v] == u) {
                    v
                } else if (tree.parent[u] == v) {
                    u
                } else {
                    continue // (u, v) is not a tree edge: nothing to update
                }
                val newWeight = query[3]
                val diff = newWeight - weight[child]
                weight[child] = newWeight
                shifts.addRange(tree.tin[child], tree.tout[child], diff)
            } else {
                val x = query[1]
                // NOTE(review): assumes every distance fits in Int, as the IntArray return type implies.
                answers[written++] = tree.baseDist[x] + shifts.pointValue(tree.tin[x])
            }
        }
        return answers
    }

    /**
     * Immutable view of the input tree as seen from root 1.
     *
     * `tin[v]..tout[v]` is the (1-based) range of DFS entry indices covering exactly `v`'s subtree.
     * `parent[1]` is 0, and `parentEdgeWeight[v]` / `baseDist[v]` are the weight of the edge to
     * `v`'s parent and `v`'s root distance in the original weights.
     */
    private class RootedTree(
        val parent: IntArray,
        val parentEdgeWeight: IntArray,
        val baseDist: IntArray,
        val tin: IntArray,
        val tout: IntArray
    )

    /** Builds the adjacency lists for [edges] and runs one DFS from node 1 to fill a [RootedTree]. */
    private fun buildTree(n: Int, edges: Array<IntArray>): RootedTree {
        val adj = Array(n + 1) { ArrayList<IntArray>() }
        for ((u, v, w) in edges) {
            adj[u].add(intArrayOf(v, w))
            adj[v].add(intArrayOf(u, w))
        }
        val parent = IntArray(n + 1)
        val parentEdgeWeight = IntArray(n + 1)
        val baseDist = IntArray(n + 1)
        val tin = IntArray(n + 1)
        val tout = IntArray(n + 1)

        // Explicit stack instead of recursion: a path-shaped tree would otherwise need one JVM
        // stack frame per node and can overflow the thread stack.
        val stack = IntArray(n)
        val cursor = IntArray(n + 1) // next adjacency index to examine for each node
        var top = 0
        var timer = 0
        stack[top++] = 1
        tin[1] = ++timer
        while (top > 0) {
            val node = stack[top - 1]
            val neighbors = adj[node]
            if (cursor[node] == neighbors.size) {
                // Every descendant has been numbered, so the subtree ends at the current timer.
                tout[node] = timer
                top--
                continue
            }
            val (child, w) = neighbors[cursor[node]++]
            if (child == parent[node]) continue
            parent[child] = node
            parentEdgeWeight[child] = w
            baseDist[child] = baseDist[node] + w
            tin[child] = ++timer
            stack[top++] = child
        }
        return RootedTree(parent, parentEdgeWeight, baseDist, tin, tout)
    }

    /** Fenwick (binary indexed) tree over positions `1..n` supporting range add and point read. */
    private class Fenwick(private val n: Int) {
        // Holds a difference array: the value at position i is the prefix sum of entries up to i.
        private val tree = IntArray(n + 1)

        private fun add(index: Int, delta: Int) {
            var i = index
            while (i <= n) {
                tree[i] += delta
                i += i and -i
            }
        }

        /** Adds [delta] to every position in `l..r` (1-based, inclusive). */
        fun addRange(l: Int, r: Int, delta: Int) {
            add(l, delta)
            add(r + 1, -delta) // no-op when r == n
        }

        /** Returns the total of all range additions covering position [index]. */
        fun pointValue(index: Int): Int {
            var i = index
            var sum = 0
            while (i > 0) {
                sum += tree[i]
                i -= i and -i
            }
            return sum
        }
    }
}