第一步就不会
那么只要算出每个和为 \(A\) 的容斥系数的和就好了
这个可以直接 \(NTT\) 预处理 \(\prod_{i=2}^{n}(1-x^{w_i})\)
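可以直接分治 \(FFT\) 求,复杂度 \(\Theta(n\log^2 n)\);也可以先求 \(\ln\) 再求 \(\exp\),复杂度 \(\Theta(n\log n)\)。因为 \(\ln\prod_{i=2}^{n}(1-x^{w_i})=-\sum_{i=2}^{n}\sum_{k\ge 1}\frac{x^{k w_i}}{k}\),把相同的 \(w_i\) 合并后枚举倍数,按调和级数就能直接求出 \(\ln\) 的每一项系数,最后再 \(\exp\) 回去即可。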
不过实测 \(\exp\) 的常数很大,跑得比分治 \(FFT\) 还慢。
# include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1 << 18;       // capacity of every polynomial / NTT buffer in this file
const int mod = 998244353;      // NTT-friendly prime: 119 * 2^23 + 1

// x <- (x + y) mod `mod`, assuming 0 <= x, y < mod.
inline void Inc(int &x, int y) {
    x += y;
    if (x >= mod) x -= mod;
}

// Returns x^y mod `mod`.
// x may be any value (it is reduced into [0, mod)). For x != 0 the exponent is reduced
// modulo mod - 1 (Fermat), so a negative y yields the corresponding inverse power.
// For x == 0: 0^0 == 1 and 0^y == 0 for every other y (0 has no inverse).
inline int Pow(ll x, ll y) {
    x %= mod;
    if (x < 0) x += mod;
    // Reducing y mod (mod - 1) is only valid for nonzero bases; handle 0 explicitly so that
    // e.g. 0^(mod - 1) is 0 rather than 1.
    if (x == 0) return y == 0 ? 1 : 0;
    y %= mod - 1;
    if (y < 0) y += mod - 1;  // keeps the shift loop below terminating
    ll ret = 1;
    for (; y; y >>= 1, x = x * x % mod)
        if (y & 1) ret = ret * x % mod;
    return static_cast<int>(ret);
}
int w[2][maxn], r[maxn], l, deg;
// Prepares the shared NTT state for transforms of length deg, the smallest power of two >= n:
//   deg, l   transform length and log2(deg)
//   r[i]     index i with its low l bits reversed (butterfly input permutation)
//   w[0][i]  g^i for the deg-th root of unity g = 3^((mod - 1) / deg)  (forward transform)
//   w[1][i]  g^(-i)                                                     (inverse transform)
// deg must not exceed maxn, the size of these tables.
inline void Init(int n) {
    for (deg = 1, l = 0; deg < n; deg <<= 1) ++l;
    assert(deg <= maxn);
    // r[0] is always 0. Starting the loop at 1 also avoids the shift by (l - 1) == -1
    // that the i == 0 iteration would perform when deg == 1.
    r[0] = 0;
    for (int i = 1; i < deg; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    const int x = Pow(3, (mod - 1) / deg);  // 3 is a primitive root of 998244353
    const int y = Pow(x, mod - 2);          // x^(-1)
    w[0][0] = w[1][0] = 1;
    for (int i = 1; i < deg; ++i) {
        w[0][i] = (ll)w[0][i - 1] * x % mod;
        w[1][i] = (ll)w[1][i - 1] * y % mod;
    }
}
inline void NTT(int *p, int opt) {
register int i, j, k, t, wn, x, y;
for (i = 0; i < deg; ++i) if (r[i] < i) swap(p[r[i]], p[i]);
for (i = 1; i < deg; i <<= 1)
for(t = i << 1, j = 0; j < deg; j += t)
for (k = 0; k < i; ++k) {
wn = w[opt == -1][deg / (i << 1) * k];
x = p[j + k], y = (ll)wn * p[i + j + k] % mod;
p[j + k] = x + y, p[i + j + k] = x - y;
if (p[j + k] >= mod) p[j + k] -= mod;
if (p[i + j + k] < 0) p[i + j + k] += mod;
}
if (opt == -1) {
for (wn = Pow(deg, mod - 2), i = 0; i < deg; ++i) p[i] = 1LL * p[i] * wn % mod;
}
}
void Inv(int *p, int *q, int len) {
if (len == 1) {
q[0] = Pow(p[0], mod - 2);
return;
}
Inv(p, q, len >> 1);
register int i, tmp = len << 1;
static int a[maxn], b[maxn];
for (Init(tmp), i = 0; i < tmp; ++i) a[i] = b[i] = 0;
for (i = 0; i < len; ++i) a[i] = p[i], b[i] = q[i];
for (NTT(a, 1), NTT(b, 1), i = 0; i < tmp; ++i) a[i] = (ll)a[i] * b[i] % mod * b[i] % mod;
for (NTT(a, -1), i = 0; i < len; ++i) q[i] = ((ll)q[i] + q[i] - a[i] + mod) % mod;
}
inline void Ln(int *p, int *q, int len) {
register int i, tmp = len << 1;
static int a[maxn], b[maxn];
for (Init(tmp), i = 0; i < tmp; ++i) a[i] = b[i] = 0;
for (Inv(p, b, len), i = 1; i < len; ++i) a[i - 1] = (ll)p[i] * i % mod;
for (NTT(a, 1), NTT(b, 1), i = 0; i < tmp; ++i) a[i] = (ll)a[i] * b[i] % mod;
for (NTT(a, -1), i = 0; i < len; ++i) q[i + 1] = (ll)a[i] * Pow(i + 1, mod - 2) % mod;
q[0] = q[len] = 0;
}
void Exp(int *p, int *q, int len) {
if (len == 1) {
q[0] = 1;
return;
}
Exp(p, q, len >> 1);
register int i, tmp = len << 1;
static int a[maxn], b[maxn];
for (Init(tmp), i = 0; i < tmp; ++i) a[i] = b[i] = 0;
for (Ln(q, a, len), i = 0; i < len; ++i) a[i] = (mod - a[i]) % mod;
for (Inc(a[0], 1), i = 0; i < len; ++i) Inc(a[i], p[i]), b[i] = q[i];
for (NTT(a, 1), NTT(b, 1), i = 0; i < tmp; ++i) a[i] = (ll)a[i] * b[i] % mod;
for (NTT(a, -1), i = 0; i < len; ++i) q[i] = a[i];
}
int n, cnt[maxn], mx, len, f[maxn], g[maxn], inv[maxn], ans, a[maxn];
int main() {
register int i, j;
for (scanf("%d%d", &n, &a[0]), i = 1; i < n; ++i) scanf("%d", &a[i]), ++cnt[a[i]], mx += a[i];
for (len = 1; len <= mx; len <<= 1);
for (inv[1] = 1, i = 2; i <= mx; ++i) inv[i] = (ll)(mod - mod / i) * inv[mod % i] % mod;
for (i = 1; i <= mx; ++i)
if (cnt[i])
for (j = i; j <= mx; j += i) Inc(f[j], mod - (ll)cnt[i] * inv[j / i] % mod);
for (Exp(f, g, len), i = 0; i <= mx; ++i)
if (g[i]) Inc(ans, (ll)g[i] * a[0] % mod * Pow(a[0] + i, mod - 2) % mod);
printf("%d\n", ans);
return 0;
}
原文:https://www.cnblogs.com/cjoieryl/p/10137943.html