过于神仙
首先有一个非常简单的暴力,就是枚举每一个点做为树根,从这个点开始扩展,直接在\(\operatorname{SAM}\)的\(\operatorname{DAG}\)上跑,一遍跑一遍统计就好了,对于一棵\(sze\)个节点的树复杂度显然是\(O(sze^2)\)
发现这又是一个树上路径问题,我们考虑一个用点分的高级暴力,对于一个分治中心\(x\),我们考虑如何拼接路径,我们设\(end_i\)表示有多少条\(u->x\)的路径对应的字符串在\(i\)位置结尾,\(beg_i\)表示有所少条\(x->v\)的路径对应的字符串从\(i\)开始,那么答案显然就是\(\sum_{i=1}^mbeg_i\times end_i\)
考虑如何求\(beg_i\)和\(end_i\),我们发现\(end_i\)本质上就是\(u->x\)这些路径在\(parent\)树上的\(endpos\),但是这里需要往前面添加字符,于是我们考虑在反串上跑,变成向后添加字符,同时\(endpos\)也变成了\(beginpos\),而\(beg_i\)本质上就是\(beginpos\)
于是现在问题变成了求\(beginpos\),后缀树能够办到这一点
众所周知,一个字符串的后缀树等价于把这个字符串的所有后缀插到一棵\(trie\)里,于是我们是可以直接利用后缀树来匹配字符串的;另一个广为人知的结论就是,一个串的后缀树等价于其反串的\(parent\)树,于是我们可以利用\(\operatorname{SAM}\)快速构建后缀树,同时我们还处理出后缀树上每一个节点到其所有儿子经过的转移边是那一条,这个利用\(\operatorname{SAM}\)也可以做到
于是我们就可以直接在后缀树上匹配,经过的点打标记,最后把所有标记下推到叶子就可以了
这个点分暴力的复杂度是\(O(nm+n\operatorname{logn})\),相比\(O(n^2)\)的暴力适用于处理联通块更大的情况
于是我们在点分的时候,如果联通块大小小于\(\sqrt{n}\),我们直接跑\(O(sze^2)\)的暴力,否则就跑这个点分暴力,这样复杂度就是\(O((n+m)\sqrt{n})\)
代码
#include<bits/stdc++.h>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=1e5+5;
struct E{int v,nxt;}e[maxn];
int head[maxn>>1],sum[maxn>>1],mx[maxn>>1],vis[maxn>>1];
int T,rt,n,m,B,num;LL ans;
char S[maxn>>1],a[maxn>>1];
struct SAM {
int len[maxn],t[maxn>>1],fa[maxn],pos[maxn],nxt[maxn][26],g[maxn];
int son[maxn][26],s[maxn>>1],tax[maxn>>1],A[maxn],sz[maxn];
int lst,cnt;
inline void ins(int c,int o) {
int p=++cnt,f=lst;lst=p;
len[p]=len[f]+1,sz[p]=1,pos[p]=o;t[o]=p;
while(f&&!son[f][c]) son[f][c]=p,f=fa[f];
if(!f) {fa[p]=1;return;}
int x=son[f][c];
if(len[f]+1==len[x]) {fa[p]=x;return;}
int y=++cnt;
len[y]=len[f]+1,fa[y]=fa[x],fa[x]=fa[p]=y;
for(re int i=0;i<26;i++) son[y][i]=son[x][i];
while(f&&son[f][c]==x) son[f][c]=y,f=fa[f];
}
inline void build() {
lst=cnt=1;
for(re int i=1;i<=m;i++) ins(s[i],i);
for(re int i=1;i<=cnt;i++) tax[len[i]]++;
for(re int i=1;i<=m;i++) tax[i]+=tax[i-1];
for(re int i=1;i<=cnt;i++) A[tax[len[i]]--]=i;
for(re int i=cnt;i;--i) {
int x=A[i];
sz[fa[x]]+=sz[x];
if(!pos[fa[x]]) pos[fa[x]]=pos[x];
nxt[fa[x]][s[pos[x]-len[fa[x]]]]=x;
}
}
inline void clear() {
for(re int i=1;i<=cnt;i++) g[i]=0;
}
inline void update() {
for(re int i=1;i<=cnt;i++) g[A[i]]+=g[fa[A[i]]];
}
void match(int x,int fa,int now,int l) {
if(l==len[now]) now=nxt[now][a[x]];
else if(s[pos[now]-l]!=a[x]) now=0;
if(!now) return;g[now]++;l++;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
match(e[i].v,x,now,l);
}
}
}p[2];
inline void add(int x,int y) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void getroot(int x,int fa) {
sum[x]=1;mx[x]=0;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
getroot(e[i].v,x);sum[x]+=sum[e[i].v];
mx[x]=max(mx[x],sum[e[i].v]);
}
mx[x]=max(mx[x],T-sum[x]);
if(mx[x]<mx[rt]) rt=x;
}
void calc(int x,int now,int fa) {
now=p[0].son[now][a[x]];
if(!now) return;
ans+=p[0].sz[now];
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
calc(e[i].v,now,x);
}
}
void solve(int x,int fa) {
calc(x,1,0);
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
solve(e[i].v,x);
}
}
void getdis(int x,int fa) {
p[0].clear(),p[1].clear();
p[0].match(x,0,1,0);p[0].update();
p[1].match(x,0,1,0);p[1].update();
for(re int i=1;i<=m;i++)
ans+=1ll*p[0].g[p[0].t[i]]*p[1].g[p[1].t[m-i+1]];
}
void del(int x,int fa) {
p[0].clear(),p[1].clear();
p[0].match(x,0,p[0].nxt[1][a[fa]],1);
p[1].match(x,0,p[1].nxt[1][a[fa]],1);
p[0].update(),p[1].update();
for(re int i=1;i<=m;i++)
ans-=1ll*p[0].g[p[0].t[i]]*p[1].g[p[1].t[m-i+1]];
}
void dfs(int x) {
if(sum[x]<=B) {solve(x,0);return;}
getdis(x,0);vis[x]=1;
for(re int i=head[x];i;i=e[i].nxt)
if(!vis[e[i].v]) del(e[i].v,x);
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]) continue;
T=sum[e[i].v];rt=0;getroot(e[i].v,0);dfs(rt);
}
}
int main() {
n=read(),m=read();B=std::ceil(std::sqrt(n));
for(re int x,y,i=1;i<n;i++)
x=read(),y=read(),add(x,y),add(y,x);
scanf("%s",a+1),scanf("%s",S+1);
for(re int i=1;i<=n;i++) a[i]-='a';
for(re int i=1;i<=m;i++) S[i]-='a';
for(re int i=1;i<=m;i++) p[0].s[i]=S[i],p[1].s[m-i+1]=S[i];
p[0].build(),p[1].build();
mx[0]=n+1,T=n,getroot(1,0);dfs(rt);
std::cout<<ans;
return 0;
}
原文:https://www.cnblogs.com/asuldb/p/11299017.html