诶,开这篇的原因之一是要学习多项式全家桶,原因之二是发现自己以前的一些模板已经非常不熟悉了,甚至于一些细节已经不知道为什么这么写了。
// Excerpt of the stage loop of NTT() further below (same names: lim, type, g, invg, cnt, poly).
// In NTT(), cnt starts at 1, so at the top of every iteration len == 2^cnt.
for (int len = 2; len <= lim; len <<= 1) {
// w = g[cnt] = G^((Mod-1)/len) (see Init), the len-th root of unity for this stage;
// the inverse transform (type <= 0) uses its modular inverse invg[cnt].
int half = len >> 1, w = (type > 0) ? g[cnt] : invg[cnt];
for (int bg = 0; bg < lim; bg += len) {
// wn walks through w^0, w^1, ..., w^(half-1) inside each block of length len.
int wn = 1;
for (int pos = bg; pos < bg + half; ++pos) {
int tmp = Mul(wn, poly[pos + half]);
// Butterfly: the upper slot is written first because both lines need the OLD poly[pos].
poly[pos + half] = Dec(poly[pos], tmp);
poly[pos] = Add(poly[pos], tmp);
wn = Mul(wn, w);
}
}
// Advance cnt together with len so that len == 2^cnt keeps holding.
++cnt;
}
下面记录我对上面这段代码产生过的疑问。
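**为什么是 `g[cnt]`?我怎么知道该用它?**

`g[x]` 是 \(2^x\) 次单位根(`Init` 中 \(g[x] = G^{(Mod-1)/2^x}\)),它能解决的是项数为 \(2^x\) 的多项式的点值转化。循环里 `len` 从 2 开始每轮翻倍,`cnt` 从 1 开始每轮加一,因此始终有 \(len = 2^{cnt}\):当前合并的区间长度为 `len`,需要的正是 `g[cnt]`。完整代码: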
#include <cctype>
#include <cstdio>
#include <iostream>
#include <utility>
using namespace std;
using ll = long long;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1, whose primitive root is G = 3.
// N is the capacity the coefficient arrays below are sized from.
const int Mod = 998244353;
const int G = 3;
const int N = 1e6 + 5;
// Fast reader: returns the next integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Returns 0 if EOF is reached before any digit (the old
// version spun forever in that case). The value is assumed to fit in an int.
inline int Rd() {
    int ret = 0, fu = 1;
    // Keep getchar()'s result as int: EOF must stay distinguishable from a data
    // byte, and isdigit() is only defined for EOF and unsigned-char values.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-')
            fu = -1;
        ch = getchar();
    }
    while (isdigit(ch))
        ret = ret * 10 + (ch - '0'), ch = getchar();
    return ret * fu;
}
// Modular helpers; operands are expected to already lie in [0, Mod).
// Add: (x + y) mod Mod. The comparison must be >= so that x + y == Mod maps to 0;
// with > it returned Mod itself and broke the [0, Mod) invariant.
inline int Add(int x, int y) { return (x + y >= Mod) ? (x + y - Mod) : (x + y); }
// Dec: (x - y) mod Mod.
inline int Dec(int x, int y) { return (x - y < 0) ? (x - y + Mod) : (x - y); }
// Mul: (x * y) mod Mod, computed in long long so the product cannot overflow.
inline int Mul(int x, int y) { return 1ll * x * y % Mod; }
// Binary exponentiation: returns x^y mod Mod for y >= 0.
// With Mod prime, Pow(a, Mod - 2) is the modular inverse of a.
inline int Pow(int x, int y) {
    int acc = 1;
    int base = x;
    while (y != 0) {
        if (y & 1)
            acc = Mul(acc, base);
        base = Mul(base, base);
        y >>= 1;
    }
    return acc;
}
// rev[i]: i with its log2(lim) low bits reversed (bit-reversal permutation).
// g[k] / invg[k]: G^((Mod - 1) / 2^k) and its modular inverse, the root used by
// the NTT stage that merges blocks of length 2^k.
int rev[N * 4];
int lim;
int g[25];
int invg[25];

// Chooses the transform length and fills the rev / g / invg tables.
// lim becomes the smallest power of two strictly greater than n + m, enough room
// for the n + m + 1 coefficients of the product.
void Init(int n, int m) {
    const int need = n + m;
    lim = 1;
    while (lim <= need)
        lim *= 2;

    const int topBit = lim >> 1;
    for (int i = 0; i < lim; ++i) {
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1)
            rev[i] |= topBit;
    }

    for (int k = 1, len = 2; len <= lim; ++k, len <<= 1) {
        g[k] = Pow(G, (Mod - 1) / len);
        invg[k] = Pow(g[k], Mod - 2);
    }
}
// In-place iterative number-theoretic transform of poly[0 .. lim).
// type > 0: forward transform; otherwise inverse transform WITHOUT the final
// division by lim (the caller multiplies by lim^-1 afterwards).
void NTT(int *poly, int type) {
    // Reorder into bit-reversed positions so the butterflies can run in place.
    for (int i = 0; i < lim; ++i) {
        if (i < rev[i])
            std::swap(poly[i], poly[rev[i]]);
    }
    // Stage k merges blocks of length len = 2^k using the root g[k] (or invg[k]).
    for (int k = 1, len = 2; len <= lim; ++k, len <<= 1) {
        const int half = len / 2;
        const int step = (type > 0) ? g[k] : invg[k];
        for (int start = 0; start < lim; start += len) {
            int omega = 1;
            for (int j = 0; j < half; ++j) {
                int &lo = poly[start + j];
                int &hi = poly[start + j + half];
                const int t = Mul(omega, hi);
                // hi must be written before lo: both need the old value of lo.
                hi = Dec(lo, t);
                lo = Add(lo, t);
                omega = Mul(omega, step);
            }
        }
    }
}
// Coefficient buffers; N * 4 leaves room for the zero padding up to lim
// (presumably the input guarantees n, m < N -- not checked here).
int n, m, A[N * 4], B[N * 4];
// Reads degrees n and m, then the coefficients of both polynomials (index i is
// the coefficient of x^i), and prints the n + m + 1 coefficients of their
// product modulo Mod.
int main() {
    n = Rd(), m = Rd();
    // Reduce first, then lift into [0, Mod): `Rd() + Mod` overflowed (UB) for
    // inputs above INT_MAX - Mod; this form also maps negatives correctly.
    for (int i = 0; i <= n; ++i)
        A[i] = (Rd() % Mod + Mod) % Mod;
    for (int i = 0; i <= m; ++i)
        B[i] = (Rd() % Mod + Mod) % Mod;
    Init(n, m);
    NTT(A, 1), NTT(B, 1);
    // Pointwise product of the point-value representations.
    for (int i = 0; i < lim; ++i)
        A[i] = Mul(A[i], B[i]);
    NTT(A, -1);
    // The inverse transform leaves every coefficient scaled by lim; divide it out.
    const int inv = Pow(lim, Mod - 2);
    for (int i = 0; i <= n + m; ++i)
        printf("%d ", Mul(A[i], inv));
    printf("\n");
    return 0;
}
请先推完式子再写 NTT/FFT,否则容易写错!写的时候要对着式子写。下面是 FFT 版本的模板:
#include <cstdio>
#include <iostream>
#include <cmath>
using namespace std;
// Capacity of f, g and rev; every index below the transform length lim is used.
const int N = 3e6 + 5;
// pi, obtained as arccos(-1).
const double PI = std::acos(-1.0);
// Minimal complex number for the FFT: x is the real part, y the imaginary part.
struct Complex {
    double x, y;
    Complex(double _x, double _y) : x(_x), y(_y) {}
    Complex() : x(0), y(0) {}
    Complex operator+(const Complex &rhs) const {
        return Complex(x + rhs.x, y + rhs.y);
    }
    Complex operator-(const Complex &rhs) const {
        return Complex(x - rhs.x, y - rhs.y);
    }
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    Complex operator*(const Complex &rhs) const {
        const double re = x * rhs.x - y * rhs.y;
        const double im = x * rhs.y + y * rhs.x;
        return Complex(re, im);
    }
};
// Coefficient / point-value buffers of the two input polynomials.
Complex f[N], g[N];
// Degrees of the inputs and the bit-reversal permutation used by FFT().
int n, m, rev[N];
// In-place iterative FFT of poly[0 .. lim), lim a power of two.
// type = 1: forward transform; type = -1: inverse transform WITHOUT the final
// division by lim (the caller divides).
void FFT(Complex *poly, int lim, int type) {
    for (int i = 0; i < lim; ++i) {
        if (rev[i] > i)
            std::swap(poly[i], poly[rev[i]]);
    }
    for (int len = 2, half = 1; len <= lim; len <<= 1, half <<= 1) {
        // The angle is 2*pi/len (not /half): step is the len-th root of unity.
        const double theta = 2 * PI / len;
        const Complex step(std::cos(theta), std::sin(theta) * type);
        for (int start = 0; start < lim; start += len) {
            Complex w(1, 0);
            for (int j = 0; j < half; ++j) {
                // The partner index is start + j + half, with no -1.
                Complex &lo = poly[start + j];
                Complex &hi = poly[start + j + half];
                const Complex t = w * hi;
                hi = lo - t;
                lo = lo + t;
                w = w * step;
            }
        }
    }
}
// Reads degrees n and m and the real coefficients of both polynomials (index i is
// the coefficient of x^i), multiplies them via FFT, and prints the n + m + 1
// product coefficients rounded to integers.
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= n; ++i) scanf("%lf", &f[i].x);
    for (int i = 0; i <= m; ++i) scanf("%lf", &g[i].x);
    // Smallest power of two that holds all n + m + 1 product coefficients.
    int lim = 1;
    const int len = n + m + 1;
    while (lim < len) lim <<= 1;
    for (int i = 1; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
    FFT(f, lim, 1);
    FFT(g, lim, 1);
    for (int i = 0; i < lim; ++i) f[i] = f[i] * g[i];
    FFT(f, lim, -1);
    // Divide out the factor lim and round to the nearest integer. lround also
    // handles negative coefficients; `(int)(x + 0.49)` truncated toward zero
    // and was off by one for them.
    for (int i = 0; i <= n + m; ++i) printf("%d ", (int) std::lround(f[i].x / lim));
    printf("\n");
    return 0;
}
原文:https://www.cnblogs.com/skiceanAKacniu/p/13172594.html