长度为 $2$ 的区间可以直接计算,之后就不再把 $2$ 当作质数,这样剩下的质数都是奇数。
考虑一对 $(i,j)$ 会作为对称位置(即 $i+j=l+r$)出现在多少个区间 $[l,r]$ 里,其中区间长度 $r-l+1\in \mathbb{P}$。
当 $i+j\leq n+1$ 时,包含一对 $(i,j)$ 的对称区间的长度取遍 $[|i-j|+1,\,i+j-1]$ 中与 $|i-j|+1$ 同奇偶的所有值。由于剩下的质数都是奇数,只有 $i+j$ 为偶数的点对才有贡献,此时这些长度恰好都是奇数。记 $s_i$ 表示 $[1,i]$ 里(不含 $2$ 的)质数个数,那么贡献就是 $a_i \times a_j \times (s_{i+j-1}-s_{|i-j|})$。
当 $i+j > n + 1$ 时,区间右端点受 $r\le n$ 限制,包含一对 $(i,j)$ 的对称区间的长度取遍 $[|i-j|+1,\,2n-(i+j)+1]$ 中同奇偶的值,贡献即为 $a_i \times a_j \times (s_{2n-(i+j)+1}-s_{|i-j|})$。
对于 $s_{|i-j|}$ 这一项,把其中一个序列翻转一下(令 $b_k=a_{n-k+1}$),$i-j$ 就变成了下标之和 $i+k=i-j+n+1$,从而变成普通的卷积。
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <utility>

// Computes, modulo MOD, the sum over every window [l, r] of prime length of
//   sum_{k=0}^{len-1} a[l+k] * a[r-k],
// i.e. the products of all ordered pairs (i, j) that are mirror images inside
// the window (i + j == l + r). Length-2 windows are summed directly. For odd
// prime lengths, each pair is weighted by the number of such windows that
// contain it. That weight is a difference of prime-prefix counts, evaluated
// with two FFT convolutions: one indexed by i + j, one by i - j.
//
// Input: n, then a[1..n]. Array sizes assume 2 * n < 2^19 (maximum FFT length).

using db = long double;

const int N = (1 << 19) + 1;
const int MOD = 1e9 + 7;

// prime[1..prin]: primes found by the sieve; vis[x]: x is composite.
// s[x]: number of primes in [1, x] excluding 2 (prefix counts after init()).
int prime[N], prin, s[N];
bool vis[N];

// Reduces any 64-bit value, negative ones included, into [0, MOD).
int norm_mod(long long x) {
  x %= MOD;
  return static_cast<int>(x < 0 ? x + MOD : x);
}

namespace FFT {
// A long double argument keeps pi at full long double precision;
// acos(-1.0) would only be accurate to double precision.
const db pi = std::acos(-1.0L);

struct Complex {
  db r, i;
  constexpr Complex() : r(0), i(0) {}
  constexpr Complex(db r_, db i_) : r(r_), i(i_) {}
  Complex operator+(const Complex &p) const { return Complex(r + p.r, i + p.i); }
  Complex operator-(const Complex &p) const { return Complex(r - p.r, i - p.i); }
  Complex operator*(const Complex &p) const {
    return Complex(r * p.r - i * p.i, r * p.i + i * p.r);
  }
};

Complex A[N], B[N];
// root[k] = exp(2*pi*i*k / n) for 0 <= k < n/2. Each entry is computed
// directly, so the twiddle factors carry no accumulated multiplication error.
Complex root[N / 2 + 1];
int n, l, r[N];

// Picks the transform length n (smallest power of two > m), fills the
// bit-reversal permutation r[] and the table of roots of unity.
void init(int m) {
  n = 1;
  l = 0;
  while (n <= m) n <<= 1, ++l;
  for (int i = 0; i < n; ++i)
    r[i] = (r[i >> 1] >> 1) | (l ? (i & 1) << (l - 1) : 0);
  for (int k = 0; k < n / 2; ++k) {
    const db angle = 2 * pi * k / n;
    root[k] = Complex(std::cos(angle), std::sin(angle));
  }
}

// In-place iterative FFT of a[0..n); pd = 1 is forward, pd = -1 is inverse
// (the inverse also divides by n).
void FFT(Complex *a, int pd) {
  for (int i = 0; i < n; ++i)
    if (i < r[i]) std::swap(a[i], a[r[i]]);
  for (int mid = 1; mid < n; mid <<= 1) {
    const int step = n / (mid << 1);  // stride into root[] for this level
    for (int j = 0; j < n; j += mid << 1) {
      for (int k = 0; k < mid; ++k) {
        const Complex &rt = root[k * step];
        const Complex w(rt.r, pd * rt.i);  // conjugate root for the inverse
        const Complex u = a[j + k];
        const Complex v = w * a[j + k + mid];
        a[j + k] = u + v;
        a[j + k + mid] = u - v;
      }
    }
  }
  if (pd == -1)
    for (int i = 0; i < n; ++i) a[i] = Complex(a[i].r / n, a[i].i / n);
}

// a <- a (*) b (cyclic convolution of length n); b is overwritten by its transform.
void Mul(Complex *a, Complex *b) {
  FFT(a, 1);
  FFT(b, 1);
  for (int i = 0; i < n; ++i) a[i] = a[i] * b[i];
  FFT(a, -1);
}

// Rounds the real part of a convolution coefficient to the nearest integer
// (std::llround is also correct for negative values) and reduces it mod MOD.
int coef(const Complex &c) { return norm_mod(std::llround(c.r)); }

// Prints the answer for a[1..m]; requires the sieve to have filled s[0..m].
void solve(int *a, int m) {
  if (m < 1) {  // no windows at all; also avoids a zero-length transform
    printf("0\n");
    return;
  }
  init(m * 2);
  long long ans = 0;  // kept in [0, MOD)

  // Length-2 windows [i, i+1]: ordered pairs (i, i+1) and (i+1, i).
  // Reduced in 64-bit arithmetic; the old int product overflowed.
  for (int i = 1; i < m; ++i)
    ans = norm_mod(ans + 1LL * norm_mod(2LL * a[i]) * norm_mod(a[i + 1]));

  // Self-convolution: A[t] = sum of a[i] * a[j] over ordered pairs with i + j == t.
  std::fill(A, A + n, Complex());
  for (int i = 1; i <= m; ++i) A[i] = Complex(a[i], 0);
  FFT(A, 1);
  for (int i = 0; i < n; ++i) A[i] = A[i] * A[i];
  FFT(A, -1);

  // Only even index sums matter: all remaining primes are odd, and a window
  // length has the parity of |i - j| + 1. The longest window mirrored around
  // the pair has length t - 1 while t <= m + 1, and 2m - t + 1 beyond that.
  for (int t = 2; t <= 2 * m; t += 2) {
    const int longest = t <= m + 1 ? t - 1 : 2 * m - t + 1;
    ans = norm_mod(ans + 1LL * s[longest] * coef(A[t]));
  }

  // Convolution with the reversed sequence: B[i] = a[m - i + 1], so A[k]
  // collects the pairs with i - j == k - m - 1. Subtract the primes that are
  // shorter than the shortest window |i - j| + 1.
  std::fill(A, A + n, Complex());
  std::fill(B, B + n, Complex());
  for (int i = 1; i <= m; ++i) {
    A[i] = Complex(a[i], 0);
    B[i] = Complex(a[m - i + 1], 0);
  }
  Mul(A, B);
  for (int k = 1; k <= 2 * m; ++k) {
    const int gap = std::abs(k - m - 1);
    if (gap % 2 == 0) ans = norm_mod(ans - 1LL * s[gap] * coef(A[k]));
  }
  printf("%lld\n", ans);
}
}  // namespace FFT

// Linear sieve up to n, then turns s[] into prefix counts of primes other than 2.
void init(const int n) {
  for (int i = 2; i <= n; ++i) {
    if (!vis[i]) prime[++prin] = i, s[i] = 1;
    for (int j = 1; j <= prin && i * prime[j] <= n; ++j) {
      vis[i * prime[j]] = true;
      if (i % prime[j] == 0) break;
    }
  }
  s[2] = 0;  // length 2 is handled separately in FFT::solve
  for (int i = 3; i <= n; ++i) s[i] += s[i - 1];
}

int a[N], n;

int main() {
  scanf("%d", &n);
  for (int i = 1; i <= n; ++i) scanf("%d", a + i);
  // s[] is only read at indices <= n, so sieving up to n is enough; the old
  // bound 2n + 100 could run past the N-sized arrays for large n.
  init(std::max(n, 2));
  FFT::solve(a, n);
  return 0;
}
原文:https://www.cnblogs.com/Mrzdtz220/p/12257668.html