题面好评,daknb
题意简述:给一个\(n-1\)次多项式,求它的\(m\)次幂的\(k\)次项系数\(bmod 2\)的结果。\(1\leq n\leq 3000, 1\leq m\leq10^9\)。
首先是一个性质:\(\bmod 2\)意义下,\(f^2(x)\)的\(i\)次项系数为\(1\),当且仅当\(i\bmod 2 =0\),且\(f(x)\)的\(\frac{i}{2}\)次项为\(1\)。
证明很简单,考虑\(i\not=j\)且\([x^i]f(x)=[x^j]f(x)=1\),则\([x^{i+j}]f^2(x)\)会同时被\([x^i]\cdot[x^j]\)和\([x^j]\cdot[x^i]\)影响,相互抵消为\(0\)。因此只有\([x^i]f(x)\)对\([x^{2i}]f^2(x)\)的影响才会被保留。
有了这个性质,我们可以简单求得\(f^{2^p}(x)\),并将\(m\)用\(\sum_{p=0}^{\infty}b_p\cdot 2^p, b_p\in \{0, 1\}\)表示。
发现有如下dp:设\(g(i, j)\)为从低到高考虑到\(f^{2^i}(x)\),二进制下第\(0, 1, \ldots, i\)位均与\(k\)的对应二进制位相同,且将第\(i+1\)个二进制位看作最低位之后的\(j\)次项的系数(即,在实际的多项式中,\(g(i, j)\)代表\(2^{i+1}\cdot j+k\bmod 2^{i+1}\)次项系数),\(w_i=\lfloor\frac{k}{2^{i}}\rfloor \bmod 2\),则:
\[g(i, j)= \begin{cases} [j=0], &i=-1\g(i-1, 2\cdot j+w), & i\not=0, \lfloor\frac{m}{2^i}\rfloor\bmod 2=0\\sum_{l=0}^{n-1} g(i-1, 2\cdot j+w-l)\cdot [x^l]f(x), & \text{otherwise} \end{cases} \]
其实\(\lfloor\frac{m}{2^i}\rfloor\bmod 2=0\)的情况等价于\(g\)卷积上了\(\epsilon(x)=[x=1]\),而\(\text{otherwise}\)等价于\(g\)卷积上了\(f^{2^i}(x)\),只是我们通过一些方式,只保留了我们需要的系数。最后再进行最低位与\(w_i\)相同的位置整体除以\(2\)的变换。
最终答案就是\(g(min(\log_2 k, \log_2 m), 0)\)。
直接dp,复杂度\(O(\log(n\cdot m)\cdot n^2)\),略显吃紧。发现系数全都是\(0/1\),用\rm{bitset}优化转移即可。
#include <cstdio>
#include <cctype>
#include <bitset>
#include <cstring>
#include <cassert>
#include <iostream>
#include <algorithm>
#define R register
#define ll long long
using namespace std;
const int N = 6100;
int t, n, m;
ll k;
bitset<N> f, tmp, a;
template <class T> inline void read(T &x) {
x = 0;
char ch = getchar(), w = 0;
while (!isdigit(ch)) w = (ch == '-'), ch = getchar();
while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
x = w ? -x : x;
return;
}
int main() {
read(t);
while (t--) {
read(n), read(m), read(k);
a.reset(), f.reset();
for (R int i = 0, x; i < n; ++i)
read(x), a[i] = x;
if (!k) {
cout << (m ? a[0] : 1) << endl;
continue;
}
f[0] = 1;
for (R int i = 0; (1ll << i) <= max(k, (ll) m); ++i) {
tmp.reset();
if (m & (1ll << i)) {
for (R int j = 0; j < n; ++j)
if (a[j])
tmp ^= f << j;
}
else
tmp = f;
for (R int j = 0; j < n; ++j)
f[j] = tmp[(j << 1) | ((k >> i) & 1)];
}
cout << f[0] << endl;
}
return 0;
}
原文:https://www.cnblogs.com/suwakow/p/11679615.html