Sum of Distances in Tree

Sanjeev Sharma
3 min read

Advertisement

Problem

Given an undirected tree with n nodes labeled 0 to n - 1 and its n - 1 edges (edges[i] = [a, b]), return an array answer where answer[i] is the sum of the distances from node i to all other nodes.

Key insight: Two-pass DFS (rerooting technique):

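  1. DFS1 from root 0: compute count[node] (subtree size) and dist[node] (sum of distances from node to the nodes in its own subtree); after this pass dist[root] already holds the root's final answer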
  2. DFS2 rerooting: dist[child] = dist[parent] - count[child] + (n - count[child])

Solutions

// C++
// Returns answer[i] = sum of distances from node i to every other node of a
// tree with n nodes (labels 0..n-1) described by the undirected `edges`.
//
// Rerooting in two linear passes over a BFS order rooted at node 0. An explicit
// order is used instead of recursion so that deep trees (e.g. a path of tens of
// thousands of nodes) cannot overflow the call stack.
//   Pass 1 (children before parents): cnt[u] = size of u's subtree and
//           dist[u] = sum of distances from u to the nodes in its subtree.
//   Pass 2 (parents before children): moving the root from p to child v brings
//           cnt[v] nodes one step closer and n - cnt[v] nodes one step farther.
// Returns an empty vector when n <= 0. Only nodes reachable from node 0 are
// processed, so the input is expected to be a connected tree.
// Note: results are int; on a path-shaped tree the sums exceed INT_MAX once n
// is roughly 65k.
std::vector<int> sumOfDistancesInTree(int n, std::vector<std::vector<int>>& edges) {
    if (n <= 0) return {};  // no node 0 to root the tree at

    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }

    // BFS from node 0: every parent appears in `order` before its children.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    order.push_back(0);
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                order.push_back(v);
            }
        }
    }

    std::vector<int> cnt(n, 1), dist(n, 0);
    // Pass 1: walk the order backwards so each child is folded into its parent
    // only after its own subtree is complete. order[0] is the root (no parent).
    for (std::size_t i = order.size(); i-- > 1;) {
        const int v = order[i];
        const int p = parent[v];
        cnt[p] += cnt[v];
        dist[p] += dist[v] + cnt[v];
    }
    // Pass 2: walk forwards; dist[parent] is already final when a child is reached.
    for (std::size_t i = 1; i < order.size(); ++i) {
        const int v = order[i];
        dist[v] = dist[parent[v]] - cnt[v] + (n - cnt[v]);
    }
    return dist;
}
// Java
/**
 * Returns answer[i] = sum of distances from node i to every other node of a
 * tree with n nodes (labels 0..n-1) described by the undirected edges.
 * Uses two rerooting passes rooted at node 0; see dfs1/dfs2.
 * Returns an empty array when n <= 0.
 */
public int[] sumOfDistancesInTree(int n, int[][] edges) {
    if (n <= 0) return new int[0];  // no node 0 to root the tree at
    List<List<Integer>> adj = new ArrayList<>();
    for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
    for (int[] e : edges) { adj.get(e[0]).add(e[1]); adj.get(e[1]).add(e[0]); }
    int[] cnt = new int[n], dist = new int[n];
    Arrays.fill(cnt, 1);
    dfs1(0, -1, adj, cnt, dist);
    dfs2(0, -1, n, adj, cnt, dist);
    return dist;
}
/**
 * Pass 1 over the subtree rooted at u (reached from parent p): adds each
 * child's subtree size into cnt and its subtree distance sum into dist.
 * Iterative (BFS order, then reverse) so deep trees cannot overflow the stack.
 */
void dfs1(int u, int p, List<List<Integer>> adj, int[] cnt, int[] dist) {
    // In a tree each node is enqueued at most once, so adj.size() slots suffice.
    int[] order = new int[adj.size()], parent = new int[adj.size()];
    int tail = 0;
    order[tail] = u;
    parent[tail++] = p;
    for (int head = 0; head < tail; head++) {
        int x = order[head];
        for (int v : adj.get(x)) if (v != parent[head]) {
            order[tail] = v;
            parent[tail++] = x;
        }
    }
    // Reverse BFS order visits children before parents; index 0 is u itself,
    // whose parent lies outside this subtree.
    for (int i = tail - 1; i > 0; i--) {
        int v = order[i], par = parent[i];
        cnt[par] += cnt[v];
        dist[par] += dist[v] + cnt[v];
    }
}
/**
 * Pass 2 from u (reached from parent p): reroots to each descendant v using
 * dist[v] = dist[parent] - cnt[v] + (n - cnt[v]). Iterative BFS so that
 * dist[parent] is final before any child of it is processed.
 */
void dfs2(int u, int p, int n, List<List<Integer>> adj, int[] cnt, int[] dist) {
    int[] queue = new int[adj.size()], parent = new int[adj.size()];
    int tail = 0;
    queue[tail] = u;
    parent[tail++] = p;
    for (int head = 0; head < tail; head++) {
        int x = queue[head];
        for (int v : adj.get(x)) if (v != parent[head]) {
            dist[v] = dist[x] - cnt[v] + (n - cnt[v]);
            queue[tail] = v;
            parent[tail++] = x;
        }
    }
}
// JavaScript
/**
 * Returns answer[i] = sum of distances from node i to every other node of a
 * tree with n nodes (labels 0..n-1) described by the undirected edges.
 * Two rerooting passes over a BFS order from node 0; iterative so that deep
 * trees cannot exceed the JavaScript call-stack limit.
 * Returns [] when n <= 0.
 * @param {number} n
 * @param {number[][]} edges
 * @returns {number[]}
 */
function sumOfDistancesInTree(n, edges) {
    if (n <= 0) return [];  // no node 0 to root the tree at
    const adj = Array.from({length: n}, () => []);
    for (const [u, v] of edges) { adj[u].push(v); adj[v].push(u); }

    // BFS from node 0: every parent appears in `order` before its children.
    const parent = Array(n).fill(-1);
    const order = [0];
    for (let head = 0; head < order.length; head++) {
        const u = order[head];
        for (const v of adj[u]) {
            if (v !== parent[u]) { parent[v] = u; order.push(v); }
        }
    }

    const cnt = Array(n).fill(1), dist = Array(n).fill(0);
    // Pass 1: children before parents (order[0] is the root and has no parent).
    for (let i = order.length - 1; i > 0; i--) {
        const v = order[i], p = parent[v];
        cnt[p] += cnt[v];
        dist[p] += dist[v] + cnt[v];
    }
    // Pass 2: parents before children, rerooting one edge at a time.
    for (let i = 1; i < order.length; i++) {
        const v = order[i];
        dist[v] = dist[parent[v]] - cnt[v] + (n - cnt[v]);
    }
    return dist;
}
# Python
from collections import defaultdict

def sumOfDistancesInTree(n, edges):
    """Return answer[i] = sum of distances from node i to all other nodes.

    The tree has n nodes labeled 0..n-1 and undirected ``edges``. Two
    rerooting passes run over a BFS order rooted at node 0. They are iterative
    because recursive DFS hits Python's default recursion limit (~1000) on
    deep trees.

    Returns [] when n <= 0.
    """
    if n <= 0:
        return []
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # BFS order from node 0: each parent precedes its children. Appending to a
    # list while iterating over it with `for` is well-defined in Python.
    parent = [-1] * n
    order = [0]
    for u in order:
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                order.append(v)

    cnt = [1] * n
    dist = [0] * n
    # Pass 1: children before parents; order[0] is the root and has no parent.
    for v in reversed(order[1:]):
        p = parent[v]
        cnt[p] += cnt[v]
        dist[p] += dist[v] + cnt[v]
    # Pass 2: parents before children, rerooting one edge at a time.
    for v in order[1:]:
        dist[v] = dist[parent[v]] - cnt[v] + (n - cnt[v])
    return dist

Complexity

  • Time: O(n)
  • Space: O(n)

Key Insight

When moving the root from parent u to child v: the cnt[v] nodes in v's subtree get 1 step closer, and the other n - cnt[v] nodes get 1 step farther. Net change = (n - cnt[v]) - cnt[v] = n - 2*cnt[v], so answer[v] = answer[u] + n - 2*cnt[v].

Advertisement

Sanjeev Sharma

Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro