给一个字符串 \(S\) ,再给 \(m\) 个字符串 \(T\) ,问 \(T\) 有多少连续非空字串 不是 \(S\) 的连续非空字串。
\[
|S| \le 5\times10^5 , \sum{T} \le10^6
\]
我们可以把题目转化为 \(T\) 有多少本质不同的子串不在 \(S\) 中出现过,我们对两个字符串都建 \(SAM\) ,首先求出 \(T\) 的每一个前缀与 \(S\) 的最大匹配长度 \(l_i\) (用 \(S\) 的 \(SAM\) 求最长公共前缀就可以求出 \(l_i\) ),再在 \(T\) 的 \(SAM\) 上求答案。我们定义节点 \(i\) 代表的集合在 \(T\) 中最先出现的位置为 \(pos_i\) ,则
\[
ans= \sum_{i=1}^{tot}{max(0,len[i]-max(len[link[i]],l_{pos_i}))}
\]
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define ll long long
using namespace std;
int read()
{
int k=0,f=1;char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
for(;isdigit(c);c=getchar()) k=k*10+c-'0';return k*f;
}
const int N=2000055;
int n,m,tot,last,link[N],len[N],ch[N][26];
int Len[N];
char s[N],t[N];
void init()
{
tot=last=0;
link[0]=-1;len[0]=0;
}
void extend(int c)
{
int p=last,cur=++tot;
len[cur]=len[p]+1;
for(;p!=-1&&!ch[p][c];p=link[p]) ch[p][c]=cur;
if(p==-1) link[cur]=0;
else
{
int q=ch[p][c];
if(len[p]+1==len[q]) link[cur]=q;
else
{
int nq=++tot;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
len[nq]=len[p]+1;link[nq]=link[q];
link[q]=link[cur]=nq;
for(;p!=-1&&ch[p][c]==q;p=link[p])
ch[p][c]=nq;
}
}
last=cur;
}
struct sam
{
int tot=0,last=0,link[N],len[N],ch[N][26],pos[N];
void init()
{
for(int i=0;i<=tot;i++)
pos[i]=link[i]=len[i]=0,memset(ch[i],0,sizeof(ch[i]));
tot=last=0;
link[0]=-1;len[0]=0;
}
void extend(int c,int x)
{
int p=last,cur=++tot;
len[cur]=len[p]+1;pos[cur]=x;
for(;p!=-1&&!ch[p][c];p=link[p]) ch[p][c]=cur;
if(p==-1) link[cur]=0;
else
{
int q=ch[p][c];
if(len[p]+1==len[q]) link[cur]=q;
else
{
int nq=++tot;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
len[nq]=len[p]+1;link[nq]=link[q];pos[nq]=pos[q];
link[q]=link[cur]=nq;
for(;p!=-1&&ch[p][c]==q;p=link[p])
ch[p][c]=nq;
}
}
last=cur;
}
ll query()
{
ll ans=0;
for(int i=1;i<=tot;i++)
ans+=max(0,len[i]-max(len[link[i]],Len[pos[i]]));;
return ans;
}
}sm;
int main()
{
int a,b;
scanf("%s",s);
n=strlen(s);
init();
for(int i=0;i<n;i++)
extend(s[i]-'a');
m=read();
for(int i=1;i<=m;i++)
{
scanf("%s",t);a=read();b=read();
n=strlen(t);
sm.init();
int now=0,l=0;ll ans=0;
for(int j=0;j<n;j++)
{
sm.extend(t[j]-'a',j);
while(now&&!ch[now][t[j]-'a']) now=link[now],l=len[now];
if(ch[now][t[j]-'a']) l++,now=ch[now][t[j]-'a'];
Len[j]=l;
}
ans=sm.query();
printf("%lld\n",ans);
}
return 0;
}
原文:https://www.cnblogs.com/waing/p/12243213.html