AI Translated from Chinese

Problem

Problem Statement

Given a tree, for each node, consider the maximum value d such that after removing d edges, the node can still reach at least one other node. Compute the sum of these d values across all nodes.

Input

The first line contains an integer n (1 ≤ n ≤ 2×10^5), the number of nodes in the tree.

Each of the next n-1 lines contains two integers u and v (1 ≤ u, v ≤ n), describing an edge between nodes u and v.

Output

A single integer: the sum of the required values over all nodes.

Solution

For each node v, we need the distance from v to the node farthest from it (the eccentricity of v), and the answer is the sum of these distances over all nodes. Note that this is not the diameter itself: the diameter is the largest eccentricity in the tree.
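To make the quantity concrete, here is a brute-force reference that runs one BFS from every node and adds up the largest distance each BFS finds. It is a sketch for checking small cases only: it takes O(n^2) time, far too slow for n = 2×10^5. The function name bruteForceEccSum and the edge-list interface are illustrative assumptions, not part of the original solution.

```cpp
#include <algorithm>
#include <queue>
#include <utility>
#include <vector>
using namespace std;

// Hypothetical brute-force helper (not from the original solution):
// BFS from every node s, take the largest distance reached (the
// eccentricity of s), and sum these values. O(n^2): small inputs only.
long long bruteForceEccSum(int n, const vector<pair<int, int>>& edges) {
    vector<vector<int>> g(n + 1);  // 1-indexed adjacency lists
    for (const auto& e : edges) {
        g[e.first].push_back(e.second);
        g[e.second].push_back(e.first);
    }
    long long total = 0;
    for (int s = 1; s <= n; s++) {
        vector<int> dist(n + 1, -1);  // -1 marks "not reached yet"
        queue<int> q;
        q.push(s);
        dist[s] = 0;
        int ecc = 0;  // largest distance from s seen so far
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            ecc = max(ecc, dist[u]);
            for (int v : g[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
        total += ecc;
    }
    return total;
}
```

For a hypothetical tree with edges 1-2, 2-3, 2-4, 4-5, the eccentricities of nodes 1 to 5 are 3, 2, 3, 2, 3, so bruteForceEccSum(5, {{1, 2}, {2, 3}, {2, 4}, {4, 5}}) returns 13. Comparing this helper against the O(n) program on small random trees is a cheap way to test the faster solution.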

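The solution uses three BFS (or DFS) passes:

  1. Run a BFS from an arbitrary node (node 1); the farthest node found, end1, is one endpoint of the tree's diameter.
  2. Run a BFS from end1; the farthest node found, end2, is the other endpoint, and this pass also gives every node's distance to end1.
  3. Run a BFS from end2 to get every node's distance to end2.
  4. For each node, the answer is max(dist1[node], dist2[node]), where dist1 and dist2 are the distances to end1 and end2 (d2 and d3 in the code below).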
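The steps above allow BFS or DFS. Because a tree has exactly one path between any two nodes, any traversal that sets dist[v] = dist[u] + 1 when it first reaches v produces correct distances. The sketch below is a hypothetical variant that swaps the queue for an explicit stack (avoiding deep recursion on a path of 2×10^5 nodes) and wraps the three passes in a function. The names treeDistances, farthest, and eccentricitySum are illustrative, not from the original code.

```cpp
#include <algorithm>
#include <stack>
#include <vector>
using namespace std;

// Distances from src in an unweighted tree, using an iterative DFS.
// In a tree, the first time a node is reached its distance is final,
// so traversal order does not matter. dist[0] stays -1 (unused slot).
vector<int> treeDistances(int n, const vector<vector<int>>& g, int src) {
    vector<int> dist(n + 1, -1);
    stack<int> st;
    st.push(src);
    dist[src] = 0;
    while (!st.empty()) {
        int u = st.top();
        st.pop();
        for (int v : g[u]) {
            if (dist[v] == -1) {
                dist[v] = dist[u] + 1;
                st.push(v);
            }
        }
    }
    return dist;
}

// Index of a farthest node; ties go to the smallest index.
int farthest(int n, const vector<int>& dist) {
    int best = 1;
    for (int i = 1; i <= n; i++)
        if (dist[i] > dist[best]) best = i;
    return best;
}

// Three passes: node 1 -> end1 -> end2, then sum max(dist to end1, dist to end2).
long long eccentricitySum(int n, const vector<vector<int>>& g) {
    int end1 = farthest(n, treeDistances(n, g, 1));
    vector<int> toEnd1 = treeDistances(n, g, end1);
    int end2 = farthest(n, toEnd1);
    vector<int> toEnd2 = treeDistances(n, g, end2);
    long long ans = 0;  // may exceed the int range for large n
    for (int i = 1; i <= n; i++) ans += max(toEnd1[i], toEnd2[i]);
    return ans;
}
```

With g built as in the solution code for the hypothetical tree 1-2, 2-3, 2-4, 4-5, treeDistances(5, g, 1) gives distances 0, 1, 2, 2, 3 for nodes 1 to 5, and eccentricitySum(5, g) returns 13.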

Code

#include <bits/stdc++.h>
using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    if(!(cin >> n)) return 0;
    vector<vector<int>> g(n+1);
    for(int i=0;i<n-1;i++){
        int u,v;cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    auto bfs = [&](int src){
        vector<int> dist(n+1, -1);
        queue<int> q;
        q.push(src);
        dist[src]=0;
        while(!q.empty()){
            int u=q.front();q.pop();
            for(int v:g[u]){
                if(dist[v]==-1){
                    dist[v]=dist[u]+1;
                    q.push(v);
                }
            }
        }
        return dist;
    };

    // First BFS to find one end of diameter
    vector<int> d1 = bfs(1);
    int end1 = 1;
    for(int i=1;i<=n;i++) if(d1[i]>d1[end1]) end1=i;

    // Second BFS from end1
    vector<int> d2 = bfs(end1);
    int end2 = end1;
    for(int i=1;i<=n;i++) if(d2[i]>d2[end2]) end2=i;

    // Third BFS from end2
    vector<int> d3 = bfs(end2);

    // Sum of eccentricities: can reach about n*(n-1) ≈ 4e10, so use long long.
    long long ans = 0;
    for(int i=1;i<=n;i++) ans += max(d2[i], d3[i]);
    cout << ans << "\n";
    return 0;
}

Explanation

The key insight is that for any node v in a tree, the largest distance from v is always attained at one of the two endpoints of a diameter (v may have several farthest nodes, but at least one diameter endpoint is among them). Therefore, the eccentricity of each node is simply the maximum of its distances to the two diameter endpoints.
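A short proof sketch of this claim, written for an unweighted tree with distance $d(\cdot,\cdot)$. Let $a, b$ be the endpoints of a diameter and $P$ the path between them; take any nodes $v$ and $u$, let $x$ be the vertex of $P$ closest to $u$, and let $w$ be the vertex of $P$ closest to $v$. Every path from $u$ to a vertex of $P$ enters $P$ at $x$, so

$$
d(a,u) = d(a,x) + d(x,u), \qquad d(b,u) = d(b,x) + d(x,u).
$$

Since $d(a,b) = d(a,x) + d(x,b)$ is maximal, $d(a,u) \le d(a,b)$ gives $d(x,u) \le d(x,b)$, and symmetrically $d(x,u) \le d(x,a)$. Assume $x$ lies on the part of $P$ between $w$ and $b$ (otherwise swap $a$ and $b$). Then

$$
d(v,u) \le d(v,w) + d(w,x) + d(x,u) \le d(v,w) + d(w,x) + d(x,b) = d(v,w) + d(w,b) = d(v,b),
$$

so $d(v,u) \le \max\bigl(d(v,a),\, d(v,b)\bigr)$ for every $u$, and the maximum distance from $v$ is reached at $a$ or $b$.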

Summing max(d2[i], d3[i]) over all nodes therefore gives the answer. Each BFS visits every node and edge once, so the whole algorithm runs in O(n) time and memory.