给出一个字符串\(s\),问\(s\)中有多少个子串是由\(k\)个相同的串拼接而成的
有个暴力的做法,枚举循环节长度然后哈希,复杂度\(O(\frac{n^2}{k})\)
现在考虑\(O(n\log{n})\)的做法,同样是枚举循环节长度\(len\),然后我们枚举循环节的起始位置(只取\(1,1+len,1+2\cdot len,\cdots\)这些关键点,每次跳\(len\)),假设现在起始位置是\(pos\),那么我们先找从\(pos\)开始的\(k-1\)个循环节,需要保证对每个\(1\le i\le k-2\)都有\(lcp(pos,pos+i\cdot len)\ge len\),即这\(k-1\)段两两相同
现在有了\(k-1\)个循环节,位置从\(L\)(即\(pos\))到\(R\)(即\(pos+(k-1)\cdot len-1\)),现在要找符合条件的\(k\)循环子串,我们只要知道\(L\)和\(R+1\)的最长公共前缀\(lcp\)和以\(R\)结尾和以\(L-1\)结尾的最长公共后缀\(lcs\),就能知道这个\(k-1\)循环节对答案的贡献,我们显然可以构造一个字符串\(s_{L-pre}s_{L-pre+1}\cdots s_{L-1}s_{L}s_{L+1}\cdots s_{R}s_{R+1}\cdots s_{R+suf}\),其中\(pre+suf=len\)且\(pre\le lcs\)且\(suf\le lcp\),所以当\(lcs+lcp\ge len\)的时候,对答案的贡献是\(lcs+lcp-len+1\),注意左边界和右边界的处理情况,有点小细节,防止重复计算(代码中把\(lcp\)限制在\(len\)以内、把\(lcs\)限制在\(len-1\)以内)
由于枚举循环节,复杂度为调和级数\(O(\sum_{i=1}^{n}\frac{n}{i})=O(n\log n)\)
//#pragma GCC optimize("O3")
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<bits/stdc++.h>
using namespace std;
// Optional fast-I/O switch: calling ____() turns off synchronisation between
// the C++ streams and C stdio and unties cin and cout.
// It is not invoked anywhere in the visible code, which reads and writes
// through scanf/printf.
function<void(void)> ____ = []() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
};
// Capacity of every per-position buffer: 300007 (= 3e5 + 7) slots, indexed from 1.
const int MAXN = 300007;
// 64-bit integer type used for the answer counter.
using LL = int_fast64_t;
// Answer accumulated for the test case currently being solved.
LL ret;
// n: length of the current string; K: required number of repetitions.
int n;
int K;
// Per-string tables; first index 0 = original string s, 1 = reversed string t.
int rk[2][MAXN];      // rk[x][i]: rank of the suffix starting at i
// Scratch buffers shared by both suffix-array builds.
int sec[MAXN];        // second-key order / alternate rank buffer
int c[MAXN];          // counting-sort buckets
int sa[2][MAXN];      // sa[x][r]: start of the suffix with rank r
int height[2][MAXN];  // height[x][r]: LCP of the suffixes ranked r-1 and r
int ST[2][MAXN][20];  // sparse table over height[x] for range-minimum queries
// Input string (1-indexed) and its reversal.
char s[MAXN];
char t[MAXN];
/**
 * Builds the suffix array of the 1-indexed string s[1..n] (n is the file-level
 * length) by prefix doubling with counting sort, then derives the rank and
 * height arrays.
 *
 * @param m      largest key value that can occur in the first counting sort
 *               (the callers pass 128 for plain character codes).
 * @param s      text; characters are read from s[1]..s[n].
 * @param rk     out: rk[i] = 1-based position of suffix i in sorted order.
 *               Slot n+1 is set to 0 and left that way.
 * @param sa     out: sa[r] = start index of the suffix with rank r.
 * @param height out: for r >= 2, height[r] = length of the longest common
 *               prefix of suffixes sa[r-1] and sa[r]; height[1] is not written.
 *
 * Uses the file-level scratch arrays `sec` and `c`.
 */
void SA(int m, char *s, int *rk, int *sa, int *height){
    // The rank comparison inside the doubling loop reads slot n+1 of both rank
    // buffers as the "past the end" key. When a longer string was processed
    // earlier, that slot still holds an old rank and two different prefixes
    // could be given the same rank. Force it to 0; ranks assigned below start
    // at 1 and slot n+1 is never written again.
    rk[n + 1] = 0;
    sec[n + 1] = 0;

    int *RK = rk, *SEC = sec;

    // Round 0: counting sort of the suffixes by their first character.
    for(int i = 0; i <= m; i++) c[i] = 0;
    for(int i = 1; i <= n; i++) c[RK[i]=s[i]]++;
    for(int i = 1; i <= m; i++) c[i] += c[i-1];
    for(int i = n; i >= 1; i--) sa[c[RK[i]]--] = i;

    // Doubling: sort by (rank of first k chars, rank of next k chars).
    for(int k = 1; k <= n; k <<= 1){
        int p = 0;
        // Order by second key: suffixes with no second half come first,
        // the rest follow in the order of the suffix k positions later.
        for(int i = n - k + 1; i <= n; i++) SEC[++p] = i;
        for(int i = 1; i <= n; i++) if(sa[i]>k) SEC[++p] = sa[i]-k;

        // Stable counting sort by first key, walking the second-key order.
        for(int i = 0; i <= m; i++) c[i] = 0;
        for(int i = 1; i <= n; i++) c[RK[SEC[i]]]++;
        for(int i = 1; i <= m; i++) c[i] += c[i-1];
        for(int i = n; i >= 1; i--) sa[c[RK[SEC[i]]]--] = SEC[i];

        // SEC now holds the old ranks; RK receives the ranks for length 2k.
        swap(RK,SEC);
        p = 0;
        RK[sa[1]] = ++p;
        for(int i = 2; i <= n; i++) RK[sa[i]] = SEC[sa[i]]==SEC[sa[i-1]] and SEC[sa[i]+k]==SEC[sa[i-1]+k] ? p : ++p;

        // All ranks distinct: the order is final.
        if(p==n) break;
        m = p;
    }

    // RK may alias either buffer after the swaps, so rebuild rk from sa.
    int k = 0;
    for(int i = 1; i <= n; i++) rk[sa[i]] = i;

    // Height array, visiting suffixes in text order so that the match length
    // found for suffix i (minus one) is the starting point for suffix i+1.
    for(int i = 1; i <= n; i++){
        if(rk[i]==1) continue;   // smallest suffix has no predecessor
        if(k) k--;
        int j = sa[rk[i]-1];
        while(i+k<=n and j+k<=n and s[i+k]==s[j+k]) k++;
        height[rk[i]] = k;
    }
}
/**
 * Fills the sparse tables ST[0] and ST[1] from height[0] and height[1] so that
 * ST[x][i][j] is the minimum of height[x][i .. i + 2^j - 1], for every window
 * that lies inside [1, n].
 */
void build_ST(){
    for(int which = 0; which < 2; ++which){
        // Level 0: windows of length one are the height values themselves.
        for(int idx = 1; idx <= n; ++idx)
            ST[which][idx][0] = height[which][idx];

        // Level lvl: combine the two half-length windows of level lvl - 1.
        for(int lvl = 1; (1 << lvl) <= n; ++lvl){
            const int half = 1 << (lvl - 1);
            const int lastStart = n - (1 << lvl) + 1;
            for(int idx = 1; idx <= lastStart; ++idx){
                const int left = ST[which][idx][lvl - 1];
                const int right = ST[which][idx + half][lvl - 1];
                ST[which][idx][lvl] = min(left, right);
            }
        }
    }
}
/**
 * Range-minimum query on sparse table ST[tg] over the index range [L, R].
 *
 * @param tg which table to use (0 or 1).
 * @param L  left end of the range (inclusive).
 * @param R  right end of the range (inclusive); requires L <= R.
 * @return   minimum of the level-0 entries ST[tg][L..R][0].
 */
int qmin(int tg, int L, int R){
    const unsigned span = static_cast<unsigned>(R - L + 1);
    // floor(log2(span)) taken from the highest set bit. This is exact integer
    // arithmetic, unlike converting the result of a floating-point log2.
    const int d = static_cast<int>(sizeof(unsigned) * 8) - 1 - __builtin_clz(span);
    // Two windows of length 2^d, anchored at each end, cover the whole range.
    return min(ST[tg][L][d],ST[tg][R-(1<<d)+1][d]);
}
int lcp(int tg, int l, int r){
int rkl = rk[tg][l], rkr = rk[tg][r];
if(rkl>rkr) swap(rkl,rkr);
return qmin(tg,rkl+1,rkr);
}
/**
 * Adds to `ret` the number of K-fold repetitions with period `len` that are
 * attributed to the anchor position `pos`.
 *
 * The caller guarantees pos + (K-1)*len - 1 < n, so K-1 whole blocks starting
 * at pos fit inside the string with at least one character after them.
 *
 * @param pos start of the first block (1-based).
 * @param len block (period) length.
 */
void calc(int pos, int len){
    // The K-1 blocks starting at pos are all equal exactly when the suffix at
    // pos and the suffix one block later agree on (K-2)*len characters.
    // For K == 2 there is a single block and nothing to check.
    if(K > 2 && lcp(0,pos,pos+len) < (K-2)*len) return;

    int L = pos, R = L + (K-1) * len - 1;

    // How far the repetition extends to the right of R, capped at one block.
    int LCP = min(len,lcp(0,L,R+1));

    // How far it extends to the left of L, capped at len-1 so that shifts by a
    // whole block are left to the previous anchor. This is the common suffix of
    // the prefixes ending at R and at L-1, taken on the reversed string.
    // For L == 1 there is nothing to the left; position n+1 of the reversed
    // string does not exist and must not be looked up.
    int LCS = L > 1 ? min(len-1,lcp(1,n+1-R,n+1-L+1)) : 0;

    // Any split pre + suf == len with pre <= LCS and suf <= LCP gives one
    // valid substring.
    if(LCP+LCS>=len) ret += LCP + LCS - len + 1;
}
void solve(){
ret = 0;
scanf("%d %s",&K,s+1);
n = strlen(s+1);
if(K==1){
printf("%I64d\n",1ll*n*(n+1)/2);
return;
}
for(int i = 1; i <= n; i++) t[i] = s[n+1-i];
SA(128,s,rk[0],sa[0],height[0]);
SA(128,t,rk[1],sa[1],height[1]);
build_ST();
for(int len = 1; len <= n; len++){
for(int i = 1; i <= n; i += len){
if(i+(K-1)*len-1>=n) break;
calc(i,len);
}
}
printf("%I64d\n",ret);
}
/**
 * Entry point: reads the number of test cases and runs solve() for each one.
 * Exits quietly when the count is missing or unreadable.
 */
int main(){
    int T = 0;
    // Without this check T would be read uninitialised after a failed scanf.
    if(scanf("%d",&T) != 1) return 0;
    // "> 0" also stops a negative count from looping.
    for(; T > 0; T--) solve();
    return 0;
}
HDU6661 Acesrc and String Theory【SA】
原文:https://www.cnblogs.com/kikokiko/p/12781777.html