Minimum Time to Collect All Apples in a Tree

Sanjeev Sharma
2 min read

Advertisement

Problem

Given an undirected tree with apples on some of its nodes, return the minimum time needed to collect all apples when starting and ending at node 0. Traversing an edge takes 1 second in either direction.

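Key insight: run a DFS from the root. Include a child's subtree only if it contains at least one apple (on the child itself or on any of its descendants). Each included child costs 2 (the round trip over its edge) plus the cost of the child's own subtree.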

Solutions

// C++
// Returns the minimum number of seconds needed to start at node 0, collect
// every apple, and come back to node 0, where crossing an edge costs 1 second.
//
// n        - number of nodes; ids are expected to be 0..n-1.
// edges    - undirected edges {u, v}; together they are expected to form a tree.
// hasApple - hasApple[i] is true when node i holds an apple.
//
// The traversal is iterative on purpose: a path-shaped tree has depth n, and
// one recursive call per level can exhaust the call stack for large n.
int minTime(int n, vector<vector<int>>& edges, vector<bool>& hasApple) {
    if (n <= 0) return 0;  // no nodes: nothing to walk, and node 0 does not exist

    vector<vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }

    // Pass 1: breadth-first discovery from the root. `order` doubles as the
    // queue, so every node is stored after its parent.
    vector<int> parentOf(n, -1);
    vector<int> order;
    order.reserve(n);
    order.push_back(0);
    for (size_t i = 0; i < order.size(); ++i) {
        const int node = order[i];
        for (int next : adj[node]) {
            if (next == parentOf[node]) continue;  // do not walk back up the tree
            parentOf[next] = node;
            order.push_back(next);
        }
    }

    // Pass 2: walk the discovery order backwards so each node's cost is final
    // before it is folded into its parent. Index 0 is the root and has no parent.
    vector<int> cost(n, 0);
    for (size_t i = order.size() - 1; i > 0; --i) {
        const int node = order[i];
        // A subtree is worth visiting when it needs any walking itself
        // (cost > 0) or its root holds an apple; the +2 is the round trip
        // over the edge that leads to it.
        if (cost[node] > 0 || hasApple[node])
            cost[parentOf[node]] += cost[node] + 2;
    }
    return cost[0];
}
// Java
/**
 * Computes the minimum time to collect all apples, starting and ending at node 0.
 *
 * Builds an undirected adjacency list for the n nodes and delegates the cost
 * computation to {@code dfs}, rooted at node 0 with no parent (-1).
 *
 * @param n        number of nodes
 * @param edges    undirected edges, each given as a two-element array {u, v}
 * @param hasApple hasApple.get(i) tells whether node i holds an apple
 * @return the value computed by {@code dfs} for the root
 */
public int minTime(int n, int[][] edges, List<Boolean> hasApple) {
    List<List<Integer>> neighbors = new ArrayList<>();
    int created = 0;
    while (created < n) {
        neighbors.add(new ArrayList<>());
        created++;
    }

    // Register every edge in both directions, since the tree is undirected.
    for (int[] edge : edges) {
        int a = edge[0];
        int b = edge[1];
        neighbors.get(a).add(b);
        neighbors.get(b).add(a);
    }

    return dfs(0, -1, neighbors, hasApple);
}
/**
 * Returns the walking cost (in seconds) of collecting every apple that lies
 * strictly below {@code node}, counting a round trip of 2 for each edge used.
 *
 * Despite the name, the traversal is iterative: the previous recursive version
 * used one JVM stack frame per tree level and could throw StackOverflowError
 * on deep, path-shaped trees. The signature is kept so existing callers work.
 *
 * @param node     root of the subtree to evaluate
 * @param parent   the neighbor of {@code node} that must not be entered, or -1 for none
 * @param adj      adjacency list; node ids are expected to be valid indexes into it
 *                 and the graph is expected to be a tree
 * @param hasApple hasApple.get(i) tells whether node i holds an apple
 * @return total cost of the required sub-walks below {@code node}
 */
int dfs(int node, int parent, List<List<Integer>> adj, List<Boolean> hasApple) {
    int n = adj.size();
    int[] parentOf = new int[n];
    int[] order = new int[n];   // discovery order; also serves as the BFS queue
    int[] cost = new int[n];

    // Pass 1: breadth-first discovery, so every node is stored after its parent.
    int head = 0;
    int tail = 0;
    parentOf[node] = parent;
    order[tail++] = node;
    while (head < tail) {
        int cur = order[head++];
        for (int next : adj.get(cur)) {
            if (next == parentOf[cur]) continue; // never walk back towards the root
            parentOf[next] = cur;
            order[tail++] = next;
        }
    }

    // Pass 2: reverse order finalizes children before their parents.
    // Index 0 is the starting node itself, which has nothing to fold into.
    for (int i = tail - 1; i > 0; i--) {
        int cur = order[i];
        // Visit a subtree only if it needs walking itself or its root has an apple.
        if (cost[cur] > 0 || hasApple.get(cur)) {
            cost[parentOf[cur]] += cost[cur] + 2;
        }
    }
    return cost[node];
}
// JavaScript
/**
 * Minimum number of seconds to start at node 0, collect every apple and return
 * to node 0, where crossing an edge costs 1 second.
 *
 * Implemented iteratively: a recursive DFS needs one call frame per tree level
 * and throws "Maximum call stack size exceeded" on deep, path-shaped trees.
 *
 * @param {number} n - number of nodes; ids are expected to be 0..n-1
 * @param {number[][]} edges - undirected edges [u, v], expected to form a tree
 * @param {boolean[]} hasApple - hasApple[i] is true when node i holds an apple
 * @returns {number} total walking time in seconds
 */
function minTime(n, edges, hasApple) {
    if (n <= 0) return 0; // no nodes, so there is nothing to walk

    const adj = Array.from({length: n}, () => []);
    for (const [u, v] of edges) { adj[u].push(v); adj[v].push(u); }

    // Pass 1: breadth-first discovery from the root. `order` doubles as the
    // queue, so every node appears after its parent.
    const parentOf = new Array(n).fill(-1);
    const order = [0];
    for (let i = 0; i < order.length; i++) {
        const node = order[i];
        for (const next of adj[node]) {
            if (next === parentOf[node]) continue; // do not walk back up
            parentOf[next] = node;
            order.push(next);
        }
    }

    // Pass 2: reverse order finalizes children before parents. Index 0 is the root.
    const cost = new Array(n).fill(0);
    for (let i = order.length - 1; i > 0; i--) {
        const node = order[i];
        // Include a subtree only if it needs walking itself or its root has an
        // apple; the +2 is the round trip over the connecting edge.
        if (cost[node] > 0 || hasApple[node]) cost[parentOf[node]] += cost[node] + 2;
    }
    return cost[0];
}
# Python
from collections import defaultdict

def minTime(n, edges, hasApple):
    """Return the minimum seconds to collect every apple, starting and ending at node 0.

    Crossing an edge costs 1 second in either direction, so every edge that has
    to be used contributes 2 (down and back up).

    Args:
        n: Number of nodes (kept for interface compatibility; the adjacency
            map is built from ``edges`` alone).
        edges: Iterable of ``(u, v)`` pairs describing an undirected graph,
            expected to be a tree that contains node 0.
        hasApple: Sequence where ``hasApple[i]`` is truthy when node ``i``
            holds an apple.

    Returns:
        Total walking time as an int.

    The traversal is iterative on purpose: a recursive DFS needs one Python
    frame per tree level and raises RecursionError on deep, path-shaped trees.
    """
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Pass 1: discover nodes with an explicit stack. A node is recorded when it
    # is popped and its children are pushed afterwards, so every node appears
    # in `order` after its parent.
    parent = {0: -1}
    order = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for child in adj[node]:
            if child != parent[node]:  # do not walk back up the tree
                parent[child] = node
                stack.append(child)

    # Pass 2: reverse discovery order finalizes children before their parents.
    cost = dict.fromkeys(order, 0)
    for node in reversed(order):
        above = parent[node]
        if above == -1:
            continue  # the root has no edge to pay for
        # Include a subtree only if it needs walking itself or its root has an
        # apple; the +2 is the round trip over the connecting edge.
        if cost[node] > 0 or hasApple[node]:
            cost[above] += cost[node] + 2
    return cost[0]

Complexity

  • Time: O(n)
  • Space: O(n)

Advertisement

Sanjeev Sharma

Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro