代码如下
#include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer used for all modular arithmetic in this file.
using LL = long long;

// NTT-friendly prime modulus (119 * 2^23 + 1).
const LL mod = 998244353;
// Primitive root of `mod`, used to build the forward-transform roots of unity.
const LL g = 3;
// Modular inverse of `g` (g * ig % mod == 1), used for the inverse transform.
const LL ig = 332748118;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a Base. Any signed 64-bit value is accepted; it is reduced into
 *          [0, mod) before use.
 * @param b Exponent. Must be non-negative (every caller in this file passes
 *          a non-negative exponent); a negative value is not supported.
 * @return  a raised to the power b, reduced modulo `mod`, in [0, mod).
 */
inline LL fpow(LL a, LL b) {
    // Normalise the base first: with 0 <= a < mod every product below stays
    // under mod^2 (< 2^63), so the multiplications cannot overflow, and a
    // negative base can no longer produce a negative result.
    a %= mod;
    if (a < 0) {
        a += mod;
    }
    LL ret = 1 % mod;
    for (; b; b >>= 1, a = a * a % mod) {
        if (b & 1) {
            ret = ret * a % mod;
        }
    }
    return ret;
}
void ntt(vector<LL> &v, vector<int> &rev, int opt) {
int tot = v.size();
for (int i = 0; i < tot; i++) {
if (i < rev[i]) {
swap(v[i], v[rev[i]]);
}
}
for (int mid = 1; mid < tot; mid <<= 1) {
LL wn = fpow(opt == 1 ? g : ig, (mod - 1) / (mid << 1));
for (int j = 0; j < tot; j += mid << 1) {
LL w = 1;
for (int k = 0; k < mid; k++) {
LL x = v[j + k], y = v[j + mid + k] * w % mod;
v[j + k] = (x + y) % mod, v[j + mid + k] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if (opt == -1) {
LL itot = fpow(tot, mod - 2);
for (int i = 0; i < tot; i++) {
v[i] = v[i] * itot % mod;
}
}
}
/**
 * Computes the first n coefficients of the multiplicative inverse of the
 * polynomial `a` (coefficients modulo `mod`) and stores them in b[0..n-1].
 *
 * The routine recurses on the first ceil(n/2) coefficients and then refines
 * them with the step  b <- b * (2 - a * b), evaluated pointwise in the NTT
 * domain.
 *
 * @param n Number of coefficients wanted. n <= 0 is a no-op.
 * @param a Input coefficients; a[0..n-1] are read. a[0] is inverted with
 *          fpow(a[0], mod - 2), so it is presumably expected to be non-zero
 *          modulo `mod` -- this is not checked here.
 * @param b Output buffer; must hold at least n elements. b[0..n-1] are
 *          overwritten.
 */
void solve(int n, vector<LL> &a, vector<LL> &b) {
    // Without this guard n == 0 yields mid == 0 below and the recursion
    // never reaches the n == 1 base case.
    if (n <= 0) {
        return;
    }
    if (n == 1) {
        b[0] = fpow(a[0], mod - 2);
        return;
    }
    const int mid = (n + 1) >> 1;
    solve(mid, a, b);

    // a (n terms) times b^2 (2 * mid - 1 terms) has at most 2n - 1 terms, so
    // any power of two >= 2n holds the product without cyclic wrap-around.
    int bit = 0, tot = 1;
    while (tot < 2 * n) {
        tot <<= 1;
        bit++;
    }

    // Bit-reversal table for a transform of length tot (bit >= 2 here).
    vector<int> rev(tot);
    for (int i = 0; i < tot; i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }

    // Zero-padded copies: a truncated to n terms, b truncated to mid terms.
    vector<LL> foo(tot), bar(tot);
    for (int i = 0; i < n; i++) {
        foo[i] = a[i];
    }
    for (int i = 0; i < mid; i++) {
        bar[i] = b[i];
    }

    ntt(foo, rev, 1), ntt(bar, rev, 1);
    for (int i = 0; i < tot; i++) {
        // + mod keeps (2 - a*b) non-negative before the final reduction.
        bar[i] = bar[i] * (2 - foo[i] * bar[i] % mod + mod) % mod;
    }
    ntt(bar, rev, -1);

    // Only the low n coefficients are meaningful modulo x^n.
    for (int i = 0; i < n; i++) {
        b[i] = bar[i];
    }
}
/**
 * Reads n followed by n coefficients from standard input, computes the first
 * n coefficients of the polynomial inverse via solve(), and prints them
 * separated (and followed) by a single space.
 */
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    int n;
    // A failed read leaves n == 0 and a negative n would be converted to a
    // huge vector size; neither describes a valid polynomial, so stop early.
    if (!(cin >> n) || n <= 0) {
        return 0;
    }
    vector<LL> a(n), b(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }
    solve(n, a, b);
    for (int i = 0; i < n; i++) {
        cout << b[i] << " ";
    }
    return 0;
}
原文:https://www.cnblogs.com/wzj-xhjbk/p/11601151.html