看这道题目懵逼了好久, \(m <= 17\) 一眼容斥,然而并没有想到怎么求出生成树的个数。然后灵光一闪——我是不学过一个叫Prüfer编码的东西嘛?!那就完美解决啦~
Prüfer编码就是将一棵无根树映射到一串编码上的编码方法,一棵 \(n\) 个节点的树与一个长度为 \(n - 2\) 的编码串一一对应。所以我们要求合法的 = 总数 \(n ^ {n - 2}\) - 不合法的方案数。不合法的方案数 = 至少有 \(1\) 个不合法 - 至少有 \(2\) 个不合法 + 至少有\(3\) 个不合法……有何求出至少有 \(k\) 个不合法的方案数呢?
我们可以首先搜索出这 \(k\) 个限制(复杂度约为 \(2^{17}\)),然后令这\(k\) 个限制的 \(sum = \sum d[i] - 1\),\(sum\) 即为这 \(k\) 个限制中所牵涉到的节点在数列中一共应该出现的次数。满足这个限制(每一个节点出现 \(d[i] - 1\) 次)的数列个数即为 \(\frac{sum!}{\prod (d[i] - 1)!}\)。又因为这 \(sum\) 个数可以出现在长度为 \(n - 2\) 的数列中的任何位置,所以 乘上\(C(n - 2, sum)\),剩下的 \(n - 2 - sum\) 个数则可以随便选择,有 \((n - k) ^ {n - sum - 2}\) 种方案。完美~
#include <bits/stdc++.h> using namespace std; #define maxn 2000000 #define int long long #define mod 1000000007 int n, m, fac[maxn]; int cnt, tot = 1, Ans = 1, S[maxn]; bool mark[maxn], vis[maxn]; int read() { int x = 0, k = 1; char c; c = getchar(); while(c < ‘0‘ || c > ‘9‘) { if(c == ‘-‘) k = -1; c = getchar(); } while(c >= ‘0‘ && c <= ‘9‘) x = x * 10 + c - ‘0‘, c = getchar(); return x * k; } struct node { int x, d; }Q[maxn]; int Qpow(int x, int timer) { int base = 1; for(; timer; timer >>= 1, x = x * x % mod) if(timer & 1) base = base * x % mod; return base; } int C(int n, int m) { if(m > n) return 0; return fac[n] * Qpow(fac[m], mod - 2) % mod * Qpow(fac[n - m], mod - 2) % mod; } void Search(int now, int last) { if(m - last + 1 < now) return; if(!now) { int tem = 1, sum = 0; for(int i = 1; i <= cnt; i ++) tem = tem * (fac[Q[S[i]].d - 1]) % mod, sum += Q[S[i]].d - 1; if(sum > n - 2) return; tem = Qpow(tem, mod - 2); tem = tem * fac[sum] % mod * C(n - 2, sum) % mod * Qpow(n - cnt, n - sum - 2) % mod; if(cnt & 1) Ans = (Ans - tem + mod) % mod; else Ans = (Ans + tem) % mod; return; } for(int i = last; i <= m; i ++) { if(vis[Q[i].x]) continue; S[++ cnt] = i, vis[Q[i].x] = 1; Search(now - 1, i + 1); cnt --, vis[Q[i].x] = 0; } } signed main() { n = read(), m = read(); for(int i = 1; i <= m; i ++) Q[i].x = read(), Q[i].d = read(); if(n - 2 > 0) Ans = Qpow(n, n - 2); else { printf("1\n"); return 0; } fac[0] = 1; for(int i = 1; i <= n; i ++) fac[i] = fac[i - 1] * i % mod; for(int i = 1; i <= m; i ++) Search(i, 1); printf("%lld\n", Ans); return 0; }
原文:https://www.cnblogs.com/twilight-sx/p/9406260.html