Distance in Tree
找树上有多少对节点,之间路径距离是K (相邻节点距离为1)。
这个用树上DP来解,定义dp如下:
- , 子树中,从处出发的路径,长度为的条数
- ,子树中,长度为的路径的条数
所以题目答案就是dp。
dp的计算是比较直观的,DFS就完成了。
dp麻烦一些,是下面三种情况的和:
- 每个儿子的子树中长路径的个数,即对每个儿子,dp
- 从出发的长路径,总数有dp
- 以为中间节点(非端点)的长路径个数:
其中是因为路径的左右端点会被重复计算一次,减去是出发的子树和结束的子树不能是同一个,而前面有多计算的部分。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n, K;
vector<int> adj[50005];
ll dp[50005][505];
void dfs(int s, int e) {
dp[s][0] = 1;
int tot[505] = {};
for (int u: adj[s]) {
if (u == e) continue;
dfs(u, s);
for (int i = 0; i <= K; i++)
tot[i] += dp[u][i];
}
for (int i = 1; i <= K; i++)
dp[s][i] = tot[i-1];
ll sum = 0;
for (int u: adj[s]) {
if (u == e) continue;
dp[s][K] += dp[u][K];
for (int x = 1; x <= K-1; x++) // calc dp[s][K] through s, 见解答
sum += dp[u][x-1] * (dp[s][K-x] - dp[u][K-x-1]);
}
dp[s][K] += sum/2;
}
int main() {
scanf("%d %d", &n, &K);
for (int i = 0; i < n-1; i++) {
int u,v;
scanf("%d %d", &u, &v);
u--,v--;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs(0, -1);
printf("%lld", dp[0][K]);
return 0;
}