简介:目前网上关于多项式操作的文章和模板大多仍然是朴素的实现,常数巨大,这个技巧利用循环卷积优化多项式操作的常数。
(这个trick在N年前就有了,「循环卷积优化」这个名字是我瞎起的,如果有人知道这个trick的名字请联系我。)
求 \(B(x)\) 满足 \(A(x)B(x) \equiv 1 \pmod n\)。
牛顿迭代可得:
朴素的实现需要做3次长度为 \(2^{t+2}\) 的FFT,把多余的部分舍去,常数较大。
发现 \(B_{t+1}(x)\) 的前 \(2^t\) 项和 \(B_t(x)\) 一样,所以只需要求后 \(2^t\) 项,即求 \(A(x)B_t^2(x)\) 的 \(x^{2^t}\dots x^{2^{t+1}-1}\) 项系数。
设
这个结果的前 \(2^t\) 项确定是1了,所以只有后半部分是有用的。
由于 \(\deg B_t = 2^t\),如果做长度为 \(2^{t+1}\) 的卷积,多余的部分会循环到前半部分,不会影响后半部分的结果。
同样的, \(A(x)B_t^2(x) = A(x)B_t(x) \times B_t(x)\) ,卷积多余的部分会循环到前 \(2^t\) 项,后半部分不会受到影响。
所以只需要做5次长度为 \(2^{t+1}\) 的FFT,在实际测试中(用100000的数据测试)常数约为正常写法的三分之二。
下面这个是递归写法:
Polynom inverse(Polynom a) {
int n = a.size();
assert((n & n - 1) == 0);
if (n == 1) return {fpow(a[0])};
int m = n >> 1;
Polynom b = inverse(Polynom(a.begin(), a.begin() + m)), c = b;
b.resize(n);
dft(a), dft(b);
for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * b[i] % P;
idft(a);
for (int i = 0; i < m; i++) a[i] = 0;
for (int i = m; i < n; i++) a[i] = P - a[i];
dft(a);
for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * b[i] % P;
idft(a);
for (int i = 0; i < m; i++) a[i] = c[i];
return a;
}
其他操作的分析后续可能会更新,先贴个开根的代码。
Polynom sqrt(Polynom a) { // return-value: \sqrt{a}
int len = a.size();
assert((len & len - 1) == 0);
assert(a[0] == 1); // warning: sqrtMod is needed if a[0] > 1.
Polynom b(len), binv{1}, bsqr{1}; // sqrt, sqrt_inv, sqrt_sqr
Polynom foo, bar; // temp
b[0] = 1;
auto shift = [](int x) { return (x & 1 ? x + P : x) >> 1; }; // quick div 2
for (int m = 1, n = 2; n <= len; m <<= 1, n <<= 1) {
foo.resize(n), bar = binv;
for (int i = 0; i < m; i++) {
foo[i + m] = sub(sum(a[i], a[i + m]), bsqr[i]);
foo[i] = 0;
}
binv.resize(n);
dft(foo), dft(binv);
for (int i = 0; i < n; i++) foo[i] = 1LL * foo[i] * binv[i] % P;
idft(foo);
for (int i = m; i < n; i++) b[i] = shift(foo[i]);
// inv
if (n == len) break;
for (int i = 0; i < n; i++) foo[i] = b[i];
bar.resize(n), binv = bar;
dft(foo), dft(bar);
bsqr.resize(n);
for (int i = 0; i < n; i++) bsqr[i] = 1LL * foo[i] * foo[i] % P;
idft(bsqr);
for (int i = 0; i < n; i++) foo[i] = 1LL * foo[i] * bar[i] % P;
idft(foo);
for (int i = 0; i < m; i++) foo[i] = 0;
for (int i = m; i < n; i++) foo[i] = P - foo[i];
dft(foo);
for (int i = 0; i < n; i++) foo[i] = 1LL * foo[i] * bar[i] % P;
idft(foo);
for (int i = m; i < n; i++) binv[i] = foo[i];
}
return b;
}
原文:https://www.cnblogs.com/HolyK/p/13997531.html