Sum of Distances in Tree
Advertisement
Problem
Given an undirected, connected tree with n nodes labeled 0 to n - 1 and its n - 1 edges, return an array answer where answer[i] is the sum of the distances between node i and every other node.
Key insight: Two-pass DFS (rerooting technique):
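- DFS1 from the root (node 0): compute cnt[u] (the size of u's subtree) and dist[u] (the sum of distances from u to the nodes in its own subtree). After this pass only dist[root] is the final answer, because the root's subtree is the whole tree.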
- DFS2 (rerooting) from the root: for each child v of u, with dist[u] already final:
dist[v] = dist[u] - cnt[v] + (n - cnt[v])
Solutions
// C++
// Returns answer where answer[i] is the sum of distances from node i to every
// other node of an undirected tree on nodes 0..n-1 (edges must form a tree).
//
// Both passes walk a BFS order rooted at node 0 instead of recursing, so deep
// (path-shaped) trees cannot overflow the call stack:
//   1. children before parents: cnt[u] = size of u's subtree and
//      dist[u] = sum of distances from u to the nodes of its subtree;
//   2. parents before children (rerooting): moving the root from parent p to
//      child u brings cnt[u] nodes one step closer and pushes the other
//      n - cnt[u] nodes one step farther, so
//      dist[u] = dist[p] - cnt[u] + (n - cnt[u]).
// Sums are kept in int; the largest possible value is n*(n-1)/2, which fits
// for n up to ~65k. NOTE(review): widen the type if n can be larger.
// Returns an empty vector when n <= 0.
std::vector<int> sumOfDistancesInTree(int n, std::vector<std::vector<int>>& edges) {
    if (n <= 0) return {};  // indexing node 0 below would be out of range

    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }

    // BFS from node 0: every node appears in `order` after its parent.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    order.push_back(0);
    for (std::size_t i = 0; i < order.size(); ++i) {
        const int u = order[i];
        for (const int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                order.push_back(v);
            }
        }
    }

    std::vector<int> cnt(n, 1);
    std::vector<int> dist(n, 0);

    // Pass 1: reverse BFS order finishes every child before its parent.
    // Index 0 (the root) is skipped because it has no parent to fold into.
    for (std::size_t i = order.size(); i-- > 1;) {
        const int u = order[i];
        const int p = parent[u];
        cnt[p] += cnt[u];
        dist[p] += dist[u] + cnt[u];
    }

    // Pass 2: forward BFS order guarantees dist[parent] is final first.
    for (std::size_t i = 1; i < order.size(); ++i) {
        const int u = order[i];
        dist[u] = dist[parent[u]] - cnt[u] + (n - cnt[u]);
    }
    return dist;
}
// Java
/**
 * Returns answer where answer[i] is the sum of distances from node i to every
 * other node of an undirected tree on nodes 0..n-1.
 *
 * @param n     number of nodes; an empty array is returned when n <= 0
 * @param edges the n - 1 tree edges as {a, b} pairs
 * @return per-node sums of distances
 */
public int[] sumOfDistancesInTree(int n, int[][] edges) {
    // Without nodes there is nothing to compute, and dfs1/dfs2 would index node 0.
    if (n <= 0) return new int[0];
    List<List<Integer>> adj = new ArrayList<>();
    for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
    for (int[] e : edges) {
        adj.get(e[0]).add(e[1]);
        adj.get(e[1]).add(e[0]);
    }
    int[] cnt = new int[n];
    int[] dist = new int[n];
    Arrays.fill(cnt, 1); // every subtree contains at least its own root
    dfs1(0, -1, adj, cnt, dist); // subtree sizes + subtree distance sums
    dfs2(0, -1, n, adj, cnt, dist); // reroot to obtain every node's answer
    return dist;
}
/**
 * First pass. For every node x in u's subtree (the side of the tree away from p),
 * adds the sizes of x's child subtrees into cnt[x] and sets
 * dist[x] += sum over children c of (dist[c] + cnt[c]).
 * Runs iteratively so deep trees cannot throw StackOverflowError.
 * Assumes the adjacency lists describe a tree, so each node is pushed at most once.
 *
 * @param u    root of the subtree to process
 * @param p    u's parent, or -1 when u is the tree root
 * @param adj  adjacency lists indexed by node
 * @param cnt  subtree sizes, expected to be pre-filled with 1
 * @param dist subtree distance sums, expected to be pre-filled with 0
 */
void dfs1(int u, int p, List<List<Integer>> adj, int[] cnt, int[] dist) {
    int size = cnt.length;
    int[] parent = new int[size];
    int[] order = new int[size]; // preorder: each node appears after its parent
    int[] stack = new int[size];
    int visited = 0;
    int top = 0;
    parent[u] = p;
    stack[top++] = u;
    while (top > 0) {
        int x = stack[--top];
        order[visited++] = x;
        for (int v : adj.get(x)) {
            if (v != parent[x]) {
                parent[v] = x;
                stack[top++] = v;
            }
        }
    }
    // Reverse preorder finishes every descendant before its ancestor.
    // Index 0 is u itself, whose parent lies outside this subtree.
    for (int i = visited - 1; i > 0; i--) {
        int x = order[i];
        cnt[parent[x]] += cnt[x];
        dist[parent[x]] += dist[x] + cnt[x];
    }
}
/**
 * Second pass (rerooting). Walking from u away from p, sets
 * dist[v] = dist[x] - cnt[v] + (n - cnt[v]) for every child v of each visited node x.
 * Moving the root to v brings cnt[v] nodes one step closer and the other n - cnt[v]
 * nodes one step farther. Runs iteratively so deep trees cannot throw
 * StackOverflowError. Assumes the adjacency lists describe a tree.
 *
 * @param u    node whose dist value is already final
 * @param p    u's parent, or -1 when u is the tree root
 * @param n    total number of nodes in the tree
 * @param adj  adjacency lists indexed by node
 * @param cnt  subtree sizes produced by dfs1 (read only)
 * @param dist distance sums; updated in place to the full answers
 */
void dfs2(int u, int p, int n, List<List<Integer>> adj, int[] cnt, int[] dist) {
    int size = dist.length;
    int[] parent = new int[size];
    int[] stack = new int[size];
    int top = 0;
    parent[u] = p;
    stack[top++] = u;
    while (top > 0) {
        int x = stack[--top]; // dist[x] is already final when x is popped
        for (int v : adj.get(x)) {
            if (v != parent[x]) {
                dist[v] = dist[x] - cnt[v] + (n - cnt[v]);
                parent[v] = x;
                stack[top++] = v;
            }
        }
    }
}
// JavaScript
function sumOfDistancesInTree(n, edges) {
const adj = Array.from({length: n}, () => []);
for (const [u, v] of edges) { adj[u].push(v); adj[v].push(u); }
const cnt = Array(n).fill(1), dist = Array(n).fill(0);
function dfs1(u, p) {
for (const v of adj[u]) if (v !== p) {
dfs1(v, u);
cnt[u] += cnt[v]; dist[u] += dist[v] + cnt[v];
}
}
function dfs2(u, p) {
for (const v of adj[u]) if (v !== p) {
dist[v] = dist[u] - cnt[v] + (n - cnt[v]);
dfs2(v, u);
}
}
dfs1(0, -1); dfs2(0, -1);
return dist;
}
# Python
from collections import defaultdict
def sumOfDistancesInTree(n, edges):
    """Return answer where answer[i] is the sum of distances from node i to
    every other node of an undirected tree on nodes 0..n-1.

    Both passes walk a BFS order instead of recursing, so trees deeper than
    Python's recursion limit (about 1000 by default) do not raise
    RecursionError.

    Args:
        n: number of nodes; an empty list is returned when n <= 0.
        edges: the n - 1 tree edges as (a, b) pairs.

    Returns:
        List of per-node distance sums.
    """
    if n <= 0:
        return []
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # BFS from node 0. Appending to `order` while iterating over it is
    # intentional: every node lands after its parent.
    parent = [-1] * n
    order = [0]
    for u in order:
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                order.append(v)

    cnt = [1] * n
    dist = [0] * n
    # Pass 1, children before parents: subtree sizes and subtree distance sums.
    for u in reversed(order[1:]):
        p = parent[u]
        cnt[p] += cnt[u]
        dist[p] += dist[u] + cnt[u]
    # Pass 2, parents before children: reroot from parent[u] to u.
    for u in order[1:]:
        dist[u] = dist[parent[u]] - cnt[u] + (n - cnt[u])
    return dist
Complexity
- Time: O(n)
- Space: O(n)
Key Insight
When the root moves from a parent u to its child v, the cnt[v] nodes in v's subtree each get 1 step closer, and the remaining n - cnt[v] nodes each get 1 step farther. The net change is (n - cnt[v]) - cnt[v] = n - 2*cnt[v], so dist[v] = dist[u] + n - 2*cnt[v].
Advertisement
← Previous
Unique Binary Search Trees