之前向学习了一个 $\text{FFT}$ 的优化,但是像我这么弱的人每次打 $\text{FFT}$ 板子的时候都会忘记这个东西,在这里记一下。
?我们知道普通的 $\text{FFT}$ 会用到原根 $\omega_n^0,\omega_n^1\cdots\omega_n^{n-1}$ 然后这些东西会在枚举步长的时候通过 $\omega_n = e^{\frac{2\pi}{n}}$ 和 $e^{\theta i} = \cos \theta + i\sin \theta$ 这两个公式一次一次算出来。
?然而我们知道,调用三角函数是非常慢的,每次计算的时候,即使你是手写的 $\text{complex}$ 也会非常慢,这就使得这种 $\text{FFT}$ 的常数巨大无比。
?所以我们就预处理一下每次需要用到的 $\omega$ ,把每一种步长需要用到的 $\omega$ 扔到同一个数组 $W$ 里,有每种步长的 $\omega$ 连续。而因为 $\sum_{i=0}^{n} 2^i = 2^{i + 1} - 1$ ,所以每次需要访问步长为 $s$ 的 $\omega$ 时候只要访问 $W[s]$ 就可以了,将一个指针指向他,而后面的只要把指针一步一步往后移即可。
?这是 $\text{DFT}$ 的时候用的,但是我们知道 $\text{IDFT}$ 的时候用的 $\omega$ 和 $\text{DFT}$ 的时候是不一样的。
?然而我们不需要重新处理 $\text{IDFT}$ 用的 $\omega$ ,只需要把需要 $\text{FFT}$ 的 $A$ 从 $1$ 到 $n - 1$ 的值 $\text{reverse}$ 一下就行了。原理是本来 $\text{IDFT}$ 的时候需要把 $\omega$ 翻过来,但是那个有点麻烦,于是我们就把 $A$ 给翻过来就行了。由于 $\text{FFT}$ 可以被理解为一个特殊的矩阵乘法,所以你顺着搞下来和反着搞回去最后的结果是一样的,所以它是对的。
?然后下面贴了一道水题的代码来帮助理解:
例:?求有多少个从 $1,2,\cdots,n$ 中取三个元素的排列 $(a,b,c)$ 满足 $x_a=x_b-x_c$。?由于是排列,所以 $(a,b,c)$ 与 $(c,b,a)$ 视为两组解。
#include <algorithm> #include <cstdio> #include <cmath> #include <cstdlib> #include <cstring> #include <ctime> #include <iostream> #include <queue> #include <set> #include <stack> #define R register #define ll long long #define db double #define ld long double #define sqr(_x) (_x) * (_x) #define Cmax(_a, _b) ((_a) < (_b) ? (_a) = (_b), 1 : 0) #define Cmin(_a, _b) ((_a) > (_b) ? (_a) = (_b), 1 : 0) #define Max(_a, _b) ((_a) > (_b) ? (_a) : (_b)) #define Min(_a, _b) ((_a) < (_b) ? (_a) : (_b)) #define Abs(_x) (_x < 0 ? (-(_x)) : (_x)) using namespace std; namespace Dntcry { inline int read() { R int a = 0, b = 1; R char c = getchar(); for(; c < ‘0‘ || c > ‘9‘; c = getchar()) (c == ‘-‘) ? b = -1 : 0; for(; c >= ‘0‘ && c <= ‘9‘; c = getchar()) a = (a << 1) + (a << 3) + c - ‘0‘; return a * b; } inline ll lread() { R ll a = 0, b = 1; R char c = getchar(); for(; c < ‘0‘ || c > ‘9‘; c = getchar()) (c == ‘-‘) ? b = -1 : 0; for(; c >= ‘0‘ && c <= ‘9‘; c = getchar()) a = (a << 1) + (a << 3) + c - ‘0‘; return a * b; } const int Maxn = 1000010, Maxl = 600010, lim = 100000; const ld pi = acos(-1); struct Complex { ld real, imag; Complex operator + (const Complex &b) const { return (Complex) {real + b.real, imag + b.imag}; } Complex operator - (const Complex &b) const { return (Complex) {real - b.real, imag - b.imag}; } Complex operator * (const Complex &b) const { return (Complex) {real * b.real - imag * b.imag, b.real * imag + real * b.imag}; } }C[Maxl], A[Maxl], w[Maxl], wl; int n, m, x[Maxn], Cnt[Maxl], len, bit, rev[Maxl], zero; ll Ans[Maxn], Sum; void Get_Rev(R int bit) { for(R int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1); return ; } void FFT(R Complex *K, R ld DFT) { for(R int i = 0; i < len; i++) if(i < rev[i]) swap(K[i], K[rev[i]]) R Complex *W; for(R int i = 2; i <= len; i <<= 1) { for(R int j = 0, step = i >> 1; j < len; j += i) { W = w + step; for(R int k = j; k < j + step; W++, k++) { R Complex G = K[k], H = *W * K[k + step]; K[k] = G + H; K[k + step] = G - H; } } } if(DFT == -1.0) for(R int i = 0; i < len; i++) K[i].real /= 1.0 * len, K[i].imag /= 1.0 * len; return ; } int Main() { n = read(); for(R int i = 1; i <= n; i++) { x[i] = read(); if(!x[i]) zero++; x[i] += lim, m = Max(m, x[i]); Cnt[x[i]]++; } m++; for(bit = 0, len = 1; (1 << bit) < (m << 1); bit++) len <<= 1; R int tmp = len >> 1; w[tmp] = (Complex) {1.0, 0.0}; wl = w[++tmp] = (Complex) {cos(2.0 * pi / len), sin(2.0 * pi / len)}; for(tmp++; tmp < len; tmp++) w[tmp] = w[tmp - 1] * wl; for(R int i = (len >> 1) - 1; i; i--) w[i] = w[i << 1]; Get_Rev(bit); for(R int i = 0; i < m; i++) A[i] = (Complex) {1.0 * Cnt[i], 0.0}; FFT(A, 1.0); C[0] = A[0] * A[0]; for(R int i = 1; i < len; i++) C[i] = A[len - i] * A[len - i]; FFT(C, -1.0); for(R int i = 0; i < len; i++) Ans[i] = (ll)(C[i].real + 0.5); for(R int i = 1; i <= n; i++) Ans[x[i] << 1]--; for(R int i = 1; i <= n; i++) Sum += Ans[x[i] + lim]; Sum -= 2ll * zero * (n - 1); printf("%lld\n", Sum); return 0; } } int main() { return Dntcry :: Main(); }
原文:https://www.cnblogs.com/DntcryBecthlev/p/10448196.html