题面
https://www.luogu.org/problem/P1117
题解
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

// Solution for Luogu P1117.
//
// For every position of the input string the program counts how many
// "squares" (substrings of the form AA, A non-empty) start there and how many
// end there, then sums  ends[i] * starts[i + 1]  over all split points i.
// Squares are found with longest-common-prefix / longest-common-suffix
// queries, answered by two suffix arrays (one on the text, one on its
// reverse) with a sparse table over the height array.

// Suffix array of a byte string with constant-time LCP queries.
//
// All positions exposed by this class are 1-based: suffix p is the substring
// text[p-1 ..]. Every table is sized from the actual text length, so nothing
// has to be cleared between test cases.
class SuffixArray {
public:
    // Builds the suffix array, the height array and the range-minimum table
    // for `text`. An empty text yields an object whose lcp() always returns 0.
    explicit SuffixArray(const std::string& text)
        : n_(static_cast<int>(text.size())),
          text_(text),
          sa_(n_ + 1, 0),
          rank_(2 * n_ + 2, 0),  // padded: the doubling step reads rank[p + width]
          height_(n_ + 1, 0),
          floorLog2_(n_ + 1, 0) {
        if (n_ == 0) {
            return;
        }
        sortSuffixes();
        computeHeights();
        buildSparseTable();
    }

    // Length of the longest common prefix of suffix x and suffix y.
    //
    // A position outside [1, n] denotes the empty suffix, so the result is 0.
    // (The caller relies on this for position n + 1.) For x == y the whole
    // suffix length is returned.
    int lcp(int x, int y) const {
        if (x < 1 || x > n_ || y < 1 || y > n_) {
            return 0;
        }
        if (x == y) {
            return n_ - x + 1;
        }
        int lo = rank_[x];
        int hi = rank_[y];
        if (lo > hi) {
            std::swap(lo, hi);
        }
        // LCP = min(height[lo + 1 .. hi]); cover the range with two
        // overlapping power-of-two windows.
        const int lvl = floorLog2_[hi - lo];
        return std::min(sparse_[lvl][lo + 1], sparse_[lvl][hi - (1 << lvl) + 1]);
    }

private:
    // Stable counting sort of the suffixes listed in `order` (already sorted
    // by second key) by their current rank (first key); writes sa_[1..n].
    void countingSort(const std::vector<int>& order, int bucketCount,
                      std::vector<int>& bucket) {
        std::fill(bucket.begin(), bucket.begin() + bucketCount + 1, 0);
        for (int i = 1; i <= n_; ++i) {
            ++bucket[rank_[i]];
        }
        for (int i = 1; i <= bucketCount; ++i) {
            bucket[i] += bucket[i - 1];
        }
        for (int i = n_; i >= 1; --i) {
            sa_[bucket[rank_[order[i]]]--] = order[i];
        }
    }

    // Prefix-doubling construction of sa_ and rank_.
    void sortSuffixes() {
        // Initial ranks are byte values + 1, so any input byte is a valid
        // bucket index (the original only handled 'a'..'z').
        int bucketCount = 256;
        std::vector<int> bucket(std::max(bucketCount, n_) + 1, 0);
        std::vector<int> order(2 * n_ + 2, 0);
        for (int i = 1; i <= n_; ++i) {
            rank_[i] = static_cast<unsigned char>(text_[i - 1]) + 1;
            order[i] = i;
        }
        countingSort(order, bucketCount, bucket);

        int classes = 0;  // number of distinct ranks after the latest round
        for (int width = 1; classes < n_; width <<= 1) {
            // Order by second key: suffixes whose second half is empty first,
            // then the rest in the order given by the current suffix array.
            classes = 0;
            for (int i = 1; i <= width; ++i) {
                order[++classes] = n_ - width + i;
            }
            for (int i = 1; i <= n_; ++i) {
                if (sa_[i] > width) {
                    order[++classes] = sa_[i] - width;
                }
            }
            countingSort(order, bucketCount, bucket);

            // `order` now holds the previous ranks; rank_ is rebuilt from them.
            // Entries past n stay 0 in both vectors and act as the rank of the
            // empty suffix.
            std::swap(order, rank_);
            classes = 1;
            rank_[sa_[1]] = 1;
            for (int i = 2; i <= n_; ++i) {
                const int prev = sa_[i - 1];
                const int cur = sa_[i];
                const bool sameKey = order[prev] == order[cur] &&
                                     order[prev + width] == order[cur + width];
                if (!sameKey) {
                    ++classes;
                }
                rank_[cur] = classes;
            }
            bucketCount = classes;
        }
    }

    // height_[r] = LCP of the suffixes ranked r - 1 and r (height_[1] = 0).
    void computeHeights() {
        int matched = 0;
        for (int i = 1; i <= n_; ++i) {
            if (matched > 0) {
                --matched;
            }
            if (rank_[i] == 1) {
                // The smallest suffix has no predecessor.
                matched = 0;
                height_[1] = 0;
                continue;
            }
            const int j = sa_[rank_[i] - 1];
            while (i + matched <= n_ && j + matched <= n_ &&
                   text_[i + matched - 1] == text_[j + matched - 1]) {
                ++matched;
            }
            height_[rank_[i]] = matched;
        }
    }

    // Range-minimum sparse table over height_ plus a floor(log2) lookup table.
    void buildSparseTable() {
        for (int i = 2; i <= n_; ++i) {
            floorLog2_[i] = floorLog2_[i >> 1] + 1;
        }
        int levels = 1;
        while ((1 << levels) <= n_) {
            ++levels;
        }
        sparse_.assign(levels, std::vector<int>(n_ + 1, 0));
        sparse_[0] = height_;
        for (int lvl = 1; lvl < levels; ++lvl) {
            const int half = 1 << (lvl - 1);
            for (int j = 1; j + (1 << lvl) - 1 <= n_; ++j) {
                sparse_[lvl][j] =
                    std::min(sparse_[lvl - 1][j], sparse_[lvl - 1][j + half]);
            }
        }
    }

    int n_;
    std::string text_;
    std::vector<int> sa_;         // sa_[r]   = start of the r-th smallest suffix
    std::vector<int> rank_;       // rank_[p] = rank of suffix p
    std::vector<int> height_;
    std::vector<int> floorLog2_;
    std::vector<std::vector<int> > sparse_;
};

// Returns the sum over all split points i of
//   (number of squares ending at i) * (number of squares starting at i + 1).
//
// For each half-length L, checkpoints are taken at L, 2L, 3L, ...; for two
// neighbouring checkpoints i and j = i + L, `fwd` is how far the text matches
// to the right from i and j (capped at L) and `bwd` how far it matches to the
// left from i - 1 and j - 1 (capped at L - 1). When fwd + bwd >= L there is a
// run of fwd + bwd - L + 1 consecutive squares of half-length L, recorded in
// O(1) through two difference arrays.
long long countExcellentSplits(const std::string& text) {
    const int n = static_cast<int>(text.size());
    if (n < 2) {
        return 0;
    }
    const std::string reversed(text.rbegin(), text.rend());
    const SuffixArray forward(text);
    const SuffixArray backward(reversed);

    // Difference arrays; index n + 1 is written by the "end" updates.
    std::vector<long long> starts(n + 2, 0);
    std::vector<long long> ends(n + 2, 0);

    for (int len = 1; len <= n / 2; ++len) {
        for (int i = len, j = 2 * len; j <= n; i += len, j += len) {
            const int fwd = std::min(forward.lcp(i, j), len);
            // Position p of the text is position n - p + 1 of the reversed
            // text, so i - 1 maps to n - i + 2 (n + 1 when i == 1 -> LCP 0).
            const int bwd =
                std::min(backward.lcp(n - i + 2, n - j + 2), len - 1);
            if (fwd + bwd >= len) {
                const int run = fwd + bwd - len + 1;
                ++starts[i - bwd];
                --starts[i - bwd + run];
                ++ends[j + fwd - run];
                --ends[j + fwd];
            }
        }
    }

    for (int i = 1; i <= n; ++i) {
        starts[i] += starts[i - 1];
        ends[i] += ends[i - 1];
    }

    long long total = 0;
    // No square can start after the last character, so i == n adds nothing.
    for (int i = 1; i < n; ++i) {
        total += ends[i] * starts[i + 1];
    }
    return total;
}

// Reads the number of test cases, then one string per test case, and prints
// one answer per line. Stops quietly on missing or malformed input.
int main() {
    std::ios::sync_with_stdio(false);

    int testCount = 0;
    if (!(std::cin >> testCount)) {
        return 0;
    }
    while (testCount-- > 0) {
        std::string text;
        if (!(std::cin >> text)) {
            break;
        }
        std::cout << countExcellentSplits(text) << '\n';
    }
    return 0;
}
原文:https://www.cnblogs.com/shxnb666/p/11427263.html