1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
/**
 * Computes, for every node of an undirected tree, the sum of the distances
 * (edge counts) from that node to every other node.
 *
 * <p>Approach: root the tree at node 0, compute each node's depth and subtree
 * size, then derive every other node's answer from its parent's answer.
 * All traversal is iterative, so deep (path-like) trees do not exhaust the
 * call stack.
 */
class Solution {

    /**
     * Returns the sum of distances from each node to all other nodes.
     *
     * <p>Nodes that are not reachable from node 0 keep the value 0 in the
     * result. Sums are accumulated in {@code int} arithmetic.
     *
     * @param n     number of nodes, labelled {@code 0 .. n-1}
     * @param edges undirected edges, each given as a two-element array
     *              {@code {from, to}}
     * @return an array of length {@code n} where element {@code i} is the sum
     *         of distances from node {@code i} to every other node; an empty
     *         array when {@code n == 0}
     */
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        if (n == 0) {
            // There is no node 0 to root the traversal at.
            return new int[0];
        }

        List<List<Integer>> nodeToNodesList = buildAdjacency(n, edges);

        int[] parent = new int[n];
        int[] depth = new int[n];
        int[] order = new int[n];           // nodes in breadth-first visiting order
        boolean[] visited = new boolean[n];

        // Breadth-first traversal from node 0. `order` doubles as the queue:
        // `head` is the read position, `visitedCount` the write position.
        int visitedCount = 0;
        parent[0] = -1;
        visited[0] = true;
        order[visitedCount++] = 0;
        for (int head = 0; head < visitedCount; head++) {
            int node = order[head];
            for (int nextNode : nodeToNodesList.get(node)) {
                if (!visited[nextNode]) {
                    visited[nextNode] = true;
                    parent[nextNode] = node;
                    depth[nextNode] = depth[node] + 1;
                    order[visitedCount++] = nextNode;
                }
            }
        }

        int[] size = new int[n];
        int[] answer = new int[n];

        // Walk the visiting order backwards so every child is processed before
        // its parent: this accumulates subtree sizes bottom-up and, along the
        // way, the root's answer as the sum of all depths.
        for (int i = visitedCount - 1; i >= 0; i--) {
            int node = order[i];
            size[node] += 1;
            answer[0] += depth[node];
            if (parent[node] != -1) {
                size[parent[node]] += size[node];
            }
        }

        // Walk forwards so every parent is final before its children. Moving
        // the root from a parent to its child brings size[child] nodes one
        // step closer and pushes the remaining (n - size[child]) one step away.
        for (int i = 1; i < visitedCount; i++) {
            int node = order[i];
            answer[node] = answer[parent[node]] + (n - size[node]) - size[node];
        }

        return answer;
    }

    /** Builds the adjacency list of the undirected graph described by {@code edges}. */
    private List<List<Integer>> buildAdjacency(int n, int[][] edges) {
        List<List<Integer>> nodeToNodesList = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            nodeToNodesList.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            int from = edge[0];
            int to = edge[1];
            nodeToNodesList.get(from).add(to);
            nodeToNodesList.get(to).add(from);
        }
        return nodeToNodesList;
    }
}
|