并不是很难啊,把细节想好了再写就很轻松了~
code:
#include <bits/stdc++.h> #define N 200003 #define LL long long #define setIO(s) freopen(s".in","r",stdin) ,freopen(s".out","w",stdout) using namespace std; struct SAM { int tot,last; int ch[N<<1][10],pre[N<<1],len[N<<1]; void init() { last=tot=1;} void extend(int c) { int np=++tot,p=last; len[np]=len[p]+1,last=np; for(;p&&!ch[p][c];p=pre[p]) ch[p][c]=np; if(!p) pre[np]=1; else { int q=ch[p][c]; if(len[q]==len[p]+1) pre[np]=q; else { int nq=++tot; len[nq]=len[p]+1; memcpy(ch[nq],ch[q],sizeof(ch[q])); pre[nq]=pre[q],pre[q]=pre[np]=nq; for(;p&&ch[p][c]==q;p=pre[p]) ch[p][c]=nq; } } } }s1,s2; int tot; char str[N]; int lson[N*30],rson[N*30],sum[N*30],rt[N<<1],tax[N<<1],rk[N<<1]; int newnode() { return ++ tot; } void update(int &x,int l,int r,int p) { if(!x) x=newnode(); ++sum[x]; if(l==r) return ; int mid=(l+r)>>1; if(p<=mid) update(lson[x],l,mid,p); else update(rson[x],mid+1,r,p); } int query(int x,int l,int r,int L,int R) { if(!x) return 0; if(l>=L&&r<=R) { return sum[x]; } int mid=(l+r)>>1, re=0; if(L<=mid) re+=query(lson[x],l,mid,L,R); if(R>mid) re+=query(rson[x],mid+1,r,L,R); return re; } int merge(int x,int y) { if(!x||!y) return x+y; int now=newnode(); sum[now]=sum[x]+sum[y]; lson[now]=merge(lson[x],lson[y]); rson[now]=merge(rson[x],rson[y]); return now; } int getmin(int x,int l,int r) { if(l==r) return l; int mid=(l+r)>>1; if(sum[lson[x]]) return getmin(lson[x],l,mid); else return getmin(rson[x],mid+1,r); } int main() { // setIO("input"); s1.init(); s2.init(); int n,i,j,m; scanf("%d%s",&n,str+1); for(i=1;i<=n;++i) s1.extend(str[i]-‘0‘); for(i=n;i>=1;--i) { s2.extend(str[i]-‘0‘); int lst=s2.last; update(rt[lst],1,n+1,i); } // 线段树合并 for(i=1;i<=s2.tot;++i) ++tax[s2.len[i]]; for(i=1;i<=s2.tot;++i) tax[i]+=tax[i-1]; for(i=1;i<=s2.tot;++i) rk[tax[s2.len[i]]--]=i; for(i=s2.tot;i>1;--i) { int u=rk[i]; int ff=s2.pre[u]; rt[ff]=merge(rt[ff],rt[u]); } scanf("%d",&m); for(i=1;i<=m;++i) { scanf("%s",str+1); int length=strlen(str+1),pp=1,len=0,mn=n+1; for(j=1;j<=length;++j) { if(s1.ch[pp][str[j]-‘0‘]) pp=s1.ch[pp][str[j]-‘0‘],++len; else break; } for(pp=1,j=len;j>=1;--j) pp=s2.ch[pp][str[j]-‘0‘]; if(len==length) mn=getmin(rt[pp],1,n+1); LL ans=0ll; while(pp!=1) { LL tmp=query(rt[pp],1,n+1,1,mn); ans+=tmp*(min(len,s2.len[pp])-s2.len[s2.pre[pp]]); pp=s2.pre[pp]; } ans+=1ll*(mn-1ll); printf("%lld\n",ans); } return 0; }
原文:https://www.cnblogs.com/guangheli/p/11779123.html