题意:
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
题目已经给的很明白了好嘛?
串长\(n\leq 2\times 10^5\)
sol:
先考虑\(n^3\)暴力,分别枚举串1和串2后缀的左端点,然后暴力比较。
优化一下,求的是两个\(suf\)的\(LCP\),可以把俩串用别的字符隔开,然后做\(SA\)
查询\(O(1)\),总复杂度\(O(n^2)\)
继续观察,是结合的串的左半区间对右半区间每个点对\(rk\)映射以后的答案
也就是对于左右各半区间不能包含自己。
考虑子区间最小值的和减去左右各半子区间最小值的和,做三遍\(SA\)就完了。
子区间最小值的和可以单调栈维护,原理是一个点控制的最小值范围。
接近 BZOJ 3238这题。
#include<cctype>
#include<cstdio>
#include<cstring>
#include<algorithm>
const int N = 2e5 + 7;
const int lim = 2e5;
typedef long long LL;
inline int max(int a, int b){return a > b ? a : b;}
inline int min(int a, int b){return a > b ? b : a;}
int x[N*2], y[N*2], sa[N*2], rk[N*2], c[N*2];
int m = lim, n; char ssa[N*2], ssb[N*2];int het[N*2], qlog[N*2], st[N*2][23];
//#define R register
int ss[N*2];
inline void getSA() {
for (int i = 1; i <= m; i++) c[i] = 0;
for (int i = 1; i <= n; i++) c[x[i] = ss[i]]++;
for (int i = 1; i <= m; i++) c[i] += c[i - 1];
for (int i = 1; i <= n; i++) sa[c[x[i]]--] = i;
for (int siz = 1, p = 0; siz <= n; siz <<= 1, p = 0) {
for (int i = n - siz + 1; i <= n; i++) y[++p] = i;
for (int i = 1; i <= n; i++) if (sa[i] > siz) y[++p] = sa[i] - siz;
for (int i = 1; i <= m; i++) c[i] = 0;
for (int i = 1; i <= n; i++) c[x[y[i]]]++;
for (int i = 1; i <= m; i++) c[i] += c[i - 1];
for (int i = n; i >= 1; i--) sa[c[x[y[i]]]--] = y[i], y[i] = 0;
std :: swap(x, y), x[sa[1]] = 1, p = 1;
for (int i = 2; i <= n; i++) x[sa[i]] =
(y[sa[i]] == y[sa[i - 1]] && y[sa[i] + siz] == y[sa[i - 1] + siz]) ? p : ++p;
if (p == n) break; m = p;
} int k = 0;
memset(het, 0, sizeof(het));
for (int i = 1; i <= n; i++) rk[sa[i]] = i;
for (int i = 1; i <= n; i++) if (rk[i] != 1) {
if (k) k--;
int j = sa[rk[i] - 1];
while(j + k <= n && i + k <= n && ss[i + k] == ss[j + k]) k++;
het[rk[i]] = k;
} qlog[1] = 0;
memset(st, 0, sizeof(st));
for (int i = 2; i <= lim; i++) qlog[i] = qlog[i / 2] + 1;
for (int i = 1; i <= lim; i++) st[i][0] = het[i];
int logw = qlog[n];
for (int i = 1; i <= logw; i++) for (int j = 1; j + (1 << i) - 1 <= n; j++)
st[j][i] = min(st[j][i - 1], st[j + (1 << (i - 1))][i - 1]);
}
const int inf = 1e9 + 7;
inline int query(int x, int y) { if (x > y) std :: swap(x, y);
if(x == y) return inf; x += 1; int logw = qlog[y - x + 1], ans;
ans = min(st[x][logw], st[y - (1 << logw) + 1][logw]); return ans;
} int t;
int lena, lenb; int stack[N*2], top, L[N*2], R[N*2];
LL ans;
inline void solve1(int x) {
stack[++top] = 1;
for (int i = 2; i <= n; i++) {
while (top && het[stack[top]] > het[i])
R[stack[top--]] = i;
L[i] = stack[top], stack[++top] = i;
} while (top) R[stack[top--]] = n + 1;
for (int i = 1; i <= n; i++)
ans += x * (LL)(R[i] - i) * (LL)(i - L[i]) * (LL)het[i];
}
int main() {
scanf ("%s", ssa + 1), lena = strlen(ssa + 1);
scanf ("%s", ssb + 1), lenb = strlen(ssb + 1); n = lena + lenb + 1;
for (int i = 1; i <= n; i++) {
if (i <= lena) ss[i] = ssa[i];
else if (i == lena + 1) ss[i] = 155;
else if (i > lena + 1) ss[i] = ssb[i - lena - 1];
} m = 200, getSA();
solve1(1);
for (int i = 1; i <= lena; i++) ss[i] = ssa[i];
n = lena, m = 200; getSA(), solve1(-1);
for (int i = 1; i <= lenb; i++) ss[i] = ssb[i];
n = lenb, m = 200; getSA(), solve1(-1);
printf ("%lld", ans);
return 0;
}
原文:https://www.cnblogs.com/cjc030205/p/11638105.html