题意:一棵n个点的树,给定m树上的路径,需要从这m条路径中选k条,需要满足至少有一个点被所有的k条路径经过,求选择的方案数 (n,m,k <= 6e5)
一个性质:两条或多条路径的交一定是一个联通块,且一定会有至少一个点为某个路径的lca
直接统计答案的话(对于每一个点算贡献),我们发现会重复,需要去重
因为上面的那个性质,对于一个点,我们只统计至少有一个是lca的边的集合就好了
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define gi get_int()
#define int long long
const int MAXN = 6e5, MOD = 998244353;
int get_int()
{
int x = 0, y = 1;
char ch = getchar();
while (!isdigit(ch) && ch != ‘-‘)
ch = getchar();
if (ch == ‘-‘)
y = -1, ch = getchar();
while (isdigit(ch))
x = x * 10 + ch - ‘0‘, ch = getchar();
return x * y;
}
int up[MAXN][21], num2[MAXN], fac[MAXN], deep[MAXN], count[MAXN];
class Edge
{
public:
int next, to;
} edges[MAXN * 2];
int head[MAXN], eNum;
void addEdge(int from, int to)
{
edges[eNum] = (Edge) {head[from], to};
head[from] = eNum++;
}
int dfs(int now = 0, int pre = -1)
{
for (int i = head[now]; i != -1; i = edges[i].next) {
int to = edges[i].to;
if (to == pre) continue;
up[to][0] = now;
deep[to] = deep[now] + 1;
dfs(to, now);
}
}
int LCA(int u, int v)
{
if (deep[u] < deep[v])
std::swap(u, v);
for (int i = 20; i >= 0; i--) {
if (deep[up[u][i]] >= deep[v])
u = up[u][i];
}
for (int i = 20; i >= 0; i--) {
if (up[u][i] != up[v][i]) {
u = up[u][i];
v = up[v][i];
}
}
return u == v ? u : up[u][0];
}
void dfs1(int now = 0, int pre = -1)
{
for (int i = head[now]; i != -1; i = edges[i].next) {
int to = edges[i].to;
if (to == pre) continue;
dfs1(to, now);
count[now] += count[to];
}
}
int qPow(int x, int y)
{
int ans = 1;
while (y != 0) {
if (y & 1)
(ans *= x) %= MOD;
x *= x;
x %= MOD;
y >>= 1;
}
return ans;
}
int C(int x, int y)
{
if (x < y) return 0;
int ans = fac[x] * qPow((fac[x - y] * fac[y]) % MOD, MOD - 2);
return ans % MOD;
}
signed main()
{
memset(head, -1, sizeof(head));
int n = gi, m = gi, k = gi;
for (int i = 1; i < n; i++) {
int from = gi - 1, to = gi - 1;
addEdge(from, to);
addEdge(to, from);
}
dfs();
for (int j = 1; j < 21; j++)
for (int i = 0; i < n; i++)
up[i][j] = up[up[i][j - 1]][j - 1];
for (int i = 0; i < m; i++) {
int qX = gi - 1, qY = gi - 1;
int lca = LCA(qX, qY);
count[qX]++;
count[qY]++;
if (lca == 0)
count[lca]--;
else
count[lca]--, count[up[lca][0]]--;
num2[lca]++;
}
dfs1();
fac[0] = 1;
for (int i = 1; i <= MAXN; i++) {
fac[i] = (fac[i - 1] * i) % MOD;
}
int ans = 0;
for (int i = 0; i < n; i++) {
(ans += (C(count[i], k) - C(count[i] - num2[i], k) + MOD) % MOD) %= MOD;
}
std::cout << ans;
return 0;
}
原文:https://www.cnblogs.com/enisP/p/14825632.html