LeetCode in Kotlin

2867. Count Valid Paths in a Tree

Hard

There is an undirected tree with n nodes labeled from 1 to n. You are given the integer n and a 2D integer array edges of length n - 1, where edges[i] = [ui, vi] indicates that there is an edge between nodes ui and vi in the tree.

Return the number of valid paths in the tree.

A path (a, b) is valid if there exists exactly one prime number among the node labels in the path from a to b.

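Note that:

*   The path `(a, b)` is a sequence of **distinct** nodes starting with node `a` and ending with node `b` such that every two adjacent nodes in the sequence share an edge in the tree.
*   Path `(a, b)` and path `(b, a)` are considered the **same** and counted only **once**.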

Example 1:

Input: n = 5, edges = [[1,2],[1,3],[2,4],[2,5]]

Output: 4

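Explanation: The pairs with exactly one prime number on the path between them are:

*   (1, 2) since the path from 1 to 2 contains prime number 2.
*   (1, 3) since the path from 1 to 3 contains prime number 3.
*   (1, 4) since the path from 1 to 4 contains prime number 2.
*   (2, 4) since the path from 2 to 4 contains prime number 2.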

It can be shown that there are only 4 valid paths.

Example 2:

Input: n = 6, edges = [[1,2],[1,3],[2,4],[3,5],[3,6]]

Output: 6

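Explanation: The pairs with exactly one prime number on the path between them are:

*   (1, 2) since the path from 1 to 2 contains prime number 2.
*   (1, 3) since the path from 1 to 3 contains prime number 3.
*   (1, 4) since the path from 1 to 4 contains prime number 2.
*   (1, 6) since the path from 1 to 6 contains prime number 3.
*   (2, 4) since the path from 2 to 4 contains prime number 2.
*   (3, 6) since the path from 3 to 6 contains prime number 3.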

It can be shown that there are only 6 valid paths.

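Constraints:

*   `1 <= n <= 10^5`
*   `edges.length == n - 1`
*   `edges[i].length == 2`
*   `1 <= ui, vi <= n`
*   The input is generated such that `edges` represent a valid tree.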

Solution

/**
 * Counts the paths of a tree labelled `1..n` that contain exactly one prime label.
 *
 * Each call to [countPaths] works on local data only; the class keeps no state between calls.
 */
class Solution {
    /**
     * Returns the number of unordered node pairs `(a, b)` with `a != b` whose connecting
     * path contains exactly one prime label.
     *
     * The tree is walked with an explicit stack rather than recursion, so a degenerate
     * tree (for example one long chain) cannot overflow the call stack.
     *
     * @param n number of nodes; labels are `1..n`. For `n < 2` no pair exists and 0 is returned.
     * @param edges undirected edges given as `[u, v]` pairs; they are expected to form a tree
     *   that contains node 1.
     * @return the number of valid paths.
     */
    fun countPaths(n: Int, edges: Array<IntArray>): Long {
        if (n < 2) {
            return 0L
        }
        val isPrime = sievePrimes(n)
        val adjacency = buildAdjacency(n, edges)

        // Pass 1: iterative DFS from node 1, recording the preorder and every node's parent.
        // parent[1] stays 0, which is never a real label, so the root skips no neighbour.
        val parent = IntArray(n + 1)
        val preorder = IntArray(n)
        val stack = IntArray(n)
        var stackSize = 0
        var visited = 0
        stack[stackSize++] = 1
        while (stackSize > 0) {
            val node = stack[--stackSize]
            preorder[visited++] = node
            for (next in adjacency[node]) {
                if (next != parent[node]) {
                    parent[next] = node
                    stack[stackSize++] = next
                }
            }
        }

        // primeFree[x]: downward paths starting at x (including the single node x) with no prime.
        // onePrime[x]:  downward paths starting at x with exactly one prime.
        val primeFree = LongArray(n + 1) { if (isPrime[it]) 0L else 1L }
        val onePrime = LongArray(n + 1) { if (isPrime[it]) 1L else 0L }

        // Pass 2: in reverse preorder every node is final before it is merged into its parent.
        var total = 0L
        for (i in visited - 1 downTo 1) {
            val node = preorder[i]
            val up = parent[node]
            // Join paths already hanging off `up` (through `up` itself) with paths from `node`'s
            // subtree so that the combined path holds exactly one prime.
            total += primeFree[up] * onePrime[node] + onePrime[up] * primeFree[node]
            if (isPrime[up]) {
                // `up` contributes the single prime, so only prime-free child paths extend.
                onePrime[up] += primeFree[node]
            } else {
                primeFree[up] += primeFree[node]
                onePrime[up] += onePrime[node]
            }
        }
        return total
    }

    /**
     * Sieve of Eratosthenes: returns an array where index `i` is `true` iff `i` is prime,
     * for every `i` in `0..n`.
     */
    private fun sievePrimes(n: Int): BooleanArray {
        val isPrime = BooleanArray(n + 1) { it >= 2 }
        var i = 2
        while (i.toLong() * i <= n) {
            if (isPrime[i]) {
                // Smaller multiples of i were already struck by smaller primes.
                var multiple = i * i
                while (multiple <= n) {
                    isPrime[multiple] = false
                    multiple += i
                }
            }
            i++
        }
        return isPrime
    }

    /** Builds an adjacency list indexed by label (index 0 is unused) from the undirected [edges]. */
    private fun buildAdjacency(n: Int, edges: Array<IntArray>): List<MutableList<Int>> {
        val adjacency = List(n + 1) { mutableListOf<Int>() }
        for ((u, v) in edges) {
            adjacency[u].add(v)
            adjacency[v].add(u)
        }
        return adjacency
    }
}