Distance in Tree

Distance in Tree

官方题解(英文)

找树上有多少对节点,之间路径距离是K (相邻节点距离为1)。

这个用树上DP来解,定义dp[i][j][i][j]如下:

所以题目答案就是dp[0][k][0][k]

dp[i][0..k1][i][0..k-1]的计算是比较直观的,DFS就完成了。

dp[i][k][i][k]麻烦一些,是下面三种情况的和:

其中1/21/2是因为路径的左右端点会被重复计算一次,减去dp[u][kx1]\mathtt{dp}[u][k-x-1]是出发的子树和结束的子树不能是同一个,而前面有多计算的部分。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
 
int n, K; 
vector<int> adj[50005];
ll dp[50005][505];
 
void dfs(int s, int e) {
    dp[s][0] = 1;
    int tot[505] = {};
    for (int u: adj[s]) {
        if (u == e) continue;
        dfs(u, s);
        for (int i = 0; i <= K; i++)
            tot[i] += dp[u][i];
    }
    for (int i = 1; i <= K; i++)
        dp[s][i] = tot[i-1];
    ll sum = 0;
    for (int u: adj[s]) {
        if (u == e) continue;
        dp[s][K] += dp[u][K];
        for (int x = 1; x <= K-1; x++)    // calc dp[s][K] through s, 见解答
            sum += dp[u][x-1] * (dp[s][K-x] - dp[u][K-x-1]);
    }
    dp[s][K] += sum/2;
}
 
int main() {
    scanf("%d %d", &n, &K);
    for (int i = 0; i < n-1; i++) {
        int u,v;
        scanf("%d %d", &u, &v);
        u--,v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(0, -1);
    printf("%lld", dp[0][K]);
 
    return 0;
}