之前做了那么多生成函数和多项式卷积的题目,结果今天才理解了优化卷积算法的实质。
首先我们以二进制FWT or作为最简单的例子入手。
我们发现正的FWT or变换就是求$\hat{a}_j=\sum_{i\in j}a_i$,即子集和,那这个是怎么来的呢?
我们假设$a$到$\hat{a}$的转移矩阵为$X$,则
$$(\sum_{j}X_{i,j}a_j)*(\sum_{j}X_{i,j}b_j)=\sum_jX_{i,j}(\sum_{s|t=j}a_sb_t)$$
所以考虑$a_sb_t$的贡献。
$$X_{i,s}*X_{i,t}=X_{i,s|t}$$
所以对于$X$的每一行都有$X_s*X_t=X_{s|t}$
而且由于最后还要进行逆变换,也就是乘上$X^{-1}$,我们知道矩阵可以求逆当且仅当$X$的行列式不为0,所以$X$的任意两行都不相同。
根据这个,我们先假设$X$中只有0和1(因为这样是最简单的),然后$X_{s|t}=1$与$X_s=X_t=1$等价,所以就可以推出来了。
先看$n=8$的情形。
$$X=\begin{pmatrix}1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \\1 & 0 & 1 & 0 & 0 & 0 & 0 & 0 \\1 & 1 & 1 & 1 & 0 & 0 & 0 & 0 \\1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 \\1 & 1 & 0 & 0 & 1 & 1 & 0 & 0 \\1 & 0 & 1 & 0 & 1 & 0 & 1 & 0 \\1 & 1 & 1 & 1 & 1 & 1 & 1 & 1\end{pmatrix}$$
打表找规律可得
$$X_{i,j}=\prod_{k=0}^{n-1}C_{i[2^k],j[2^k]}$$
其中$i[2^k]$表示$i$在二进制下的第$k$位。
$$C=\begin{pmatrix}1 & 1 \\ 0 & 1\end{pmatrix}$$
然后我们就知道如何进行分治计算这个向量乘矩阵了。(对,就是那个三重循环)
我们也可以把FFT的矩阵也这样写出来。
$$A=\begin{pmatrix}\omega_n^0 & \omega_n^0 & \ldots & \omega_n^0 & \omega_n^0 \\\omega_n^0 & \omega_n^1 & \ldots & \omega_n^{n-2} & \omega_n^{n-1} \\\vdots & \vdots & \ddots & \ddots & \ddots \\\omega_n^0 & \omega_n^{n-1} & \ldots & \omega_n^{(n-2)(n-1)} & \omega_n^{(n-1)(n-1)}\end{pmatrix}$$
即$A_{i,j}=\omega_n^{ij}$,所以
$$A^{-1}=\frac{1}{n}\begin{pmatrix}\omega_n^{-0} & \omega_n^{-0} & \ldots & \omega_n^{-0} & \omega_n^{-0} \\\omega_n^{-0} & \omega_n^{-1} & \ldots & \omega_n^{-(n-2)} & \omega_n^{-(n-1)} \\\vdots & \vdots & \ddots & \ddots & \ddots \\\omega_n^{-0} & \omega_n^{-(n-1)} & \ldots & \omega_n^{-(n-2)(n-1)} & \omega_n^{-(n-1)(n-1)}\end{pmatrix}$$
即$A^{-1}_{i,j}=\frac{\omega_n^{-ij}}{m}$
我们设$B_{i,j}$表示$f_{i-1}$到$f_i$的转移矩阵。
定义$a\oplus b$表示三进制不进位加法,$a\ominus b$表示三进制不退位减法。易得这两个运算互为逆运算。
则$\forall k,B_{i\oplus k,j\oplus k}=B_{i,j}$,由数学归纳法得$\forall k,B_{i\oplus k,j\oplus k}^n=B_{i,j}^n$即$B_{i,j}^n=B_{0,j\ominus i}^n$
$$f_{n,i}=\sum_{j}f_{0,j}*B_{j,i}^n=\sum_{j}f_{0,j}*B_{0,i\ominus j}^n=\sum_{x\oplus y=i}f_x*B_{0,y}^n$$
所以我们只需要求出$B$矩阵的第一行并与$f_0$做三进制下的异或卷积就可以了。
我们先考虑二进制下的。
$$C=\begin{pmatrix}1 & 1 \\1 & -1\end{pmatrix}$$
($C$矩阵的意义见上)
所以感性理解一下(或者可以自己推一推),三进制的异或卷积的矩阵就是:
$$C=\begin{pmatrix}1 & 1 & 1 \\1 & \omega & \omega^2 \\1 & \omega^2 & \omega\end{pmatrix}$$
$$C^{-1}=\frac{1}{3}\begin{pmatrix}1 & 1 & 1 \\1 & \omega^2 & \omega \\1 & \omega & \omega^2\end{pmatrix}$$
其中$\omega=\frac{-1+\sqrt{3}i}{2}$
但是$\sqrt{3}$运算非常麻烦,还会有精度问题,所以我们取$1,\omega$作为基底而不是$1,i$,即把复数表示成$a+b\omega$的形式。
乘法与$a+bi$的乘法不一样,需要推一推。
$$(a+b\omega)(c+d\omega)=ac+(bc+ad)\omega+bd(-\omega-1)=(ac-bd)+(bc+ad-bd)\omega$$
然后就应该是做完了。
1 #include<cstdio> 2 #define Rint register int 3 using namespace std; 4 typedef long long LL; 5 const int N = 531441; 6 int n, m, t, p, po[13], cntx[N], cnty[N]; 7 inline void exgcd(int a, int b, int &x, int &y){ 8 if(!b){x = 1; y = 0; return;} 9 exgcd(b, a % b, y, x); y -= (LL) a / b * x; 10 } 11 struct complex { 12 int x, y; 13 inline complex(int x = 0, int y = 0): x(x), y(y){} 14 inline complex operator + (const complex &o) const {return complex((x + o.x) % p, (y + o.y) % p);} 15 inline complex operator - (const complex &o) const {return complex((x - o.x + p) % p, (y - o.y + p) % p);} 16 inline complex operator * (const complex &o) const { 17 return complex(((LL) x * o.x % p - (LL) y * o.y % p + p) % p, ((LL) y * o.x % p + (LL) x * o.y % p - (LL) y * o.y % p + p) % p); 18 } 19 } A[N], B[N]; 20 inline complex kasumi(complex a, int b){ 21 complex res = complex(1, 0); 22 while(b){ 23 if(b & 1) res = res * a; 24 a = a * a; 25 b >>= 1; 26 } 27 return res; 28 } 29 inline complex calc1(const complex &a){return complex((p - a.y) % p, (a.x - a.y + p) % p);} 30 inline complex calc2(const complex &a){return complex((a.y - a.x + p) % p, (p - a.x) % p);} 31 inline void dft(complex *A){ 32 for(Rint mid = 1;mid < n;mid *= 3) 33 for(Rint j = 0;j < n;j += mid * 3) 34 for(Rint k = 0;k < mid;k ++){ 35 complex x = A[j + k], y = A[j + k + mid], z = A[j + k + mid * 2]; 36 A[j + k] = x + y + z; 37 A[j + k + mid] = x + calc1(y) + calc2(z); 38 A[j + k + mid * 2] = x + calc1(z) + calc2(y); 39 } 40 } 41 inline void idft(complex *A){ 42 for(Rint mid = 1;mid < n;mid *= 3) 43 for(Rint j = 0;j < n;j += mid * 3) 44 for(Rint k = 0;k < mid;k ++){ 45 complex x = A[j + k], y = A[j + k + mid], z = A[j + k + mid * 2]; 46 A[j + k] = x + y + z; 47 A[j + k + mid] = x + calc1(z) + calc2(y); 48 A[j + k + mid * 2] = x + calc1(y) + calc2(z); 49 } 50 } 51 int trans[13][13]; 52 int main(){ 53 scanf("%d%d%d", &m, &t, &p); 54 po[0] = 1; 55 for(Rint i = 1;i <= m;i ++) po[i] = (LL) po[i - 1] * 3; 56 n = po[m]; 57 for(Rint i = 1;i <= m;i ++) po[i] %= p; 58 if(p == 1){ 59 for(Rint i = 0;i < n;i ++) puts("0"); 60 return 0; 61 } 62 for(Rint i = 0;i < n;i ++) scanf("%d", &A[i].x); 63 for(Rint i = 0;i <= m;i ++) 64 for(Rint j = 0;i + j <= m;j ++) scanf("%d", trans[i] + j); 65 for(Rint i = 0;i < n;i ++){ 66 cntx[i] = cntx[i / 3] + (i % 3 == 1); 67 cnty[i] = cnty[i / 3] + (i % 3 == 2); 68 B[i].x = trans[cntx[i]][cnty[i]]; 69 //printf("%d ", B[i].x); 70 } 71 //putchar(‘\n‘); 72 dft(A); dft(B); 73 //for(Rint i = 0;i < n;i ++) printf("(%d, %d)\n", A[i].x, A[i].y); 74 //for(Rint i = 0;i < n;i ++) printf("(%d, %d)\n", B[i].x, B[i].y); 75 for(Rint i = 0;i < n;i ++) A[i] = A[i] * kasumi(B[i], t); 76 idft(A); 77 int inv, tmp; 78 exgcd(n, p, inv, tmp); 79 inv = (inv + p) % p; 80 for(Rint i = 0;i < n;i ++) 81 printf("%d\n", (LL) A[i].x * inv % p); 82 }
原文:https://www.cnblogs.com/AThousandMoons/p/10926924.html