题意:给你n个母串,m个匹配串,让你求出对于每个母串 所有匹配串出现的次数和。
思路:ac自动机模板题,加入一个数组val[i] 表示以i节点结束时匹配了几个匹配串
代码:
#include <algorithm> #include <iostream> #include <sstream> #include <cstdlib> #include <cstring> #include <iomanip> #include <cstdio> #include <string> #include <bitset> #include <vector> #include <queue> #include <stack> #include <cmath> #include <list> #include <map> #include <set> #define sss(a,b,c) scanf("%d%d%d",&a,&b,&c) #define mem1(a) memset(a,-1,sizeof(a)) #define mem(a) memset(a,0,sizeof(a)) #define ss(a,b) scanf("%d%d",&a,&b) #define s(a) scanf("%d",&a) #define p(a) printf("%d\n", a) #define INF 0x3f3f3f3f #define w(a) while(a) #define PI acos(-1.0) #define LL long long #define eps 10E-9 #define N 200000+20 #define mod 1000000007 const int SIGMA_SIZE=26; const int MAXN=100010; const int MAXNODE=600010; using namespace std; void mys(int& res) { int flag=0; char ch; while(!(((ch=getchar())>='0'&&ch<='9')||ch=='-')) if(ch==EOF) res=INF; if(ch=='-') flag=1; else if(ch>='0'&&ch<='9') res=ch-'0'; while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0'; res=flag?-res:res; } void myp(int a) { if(a>9) myp(a/10); putchar(a%10+'0'); } /*************************THE END OF TEMPLATE************************/ int n,m; LL ans; string str1[MAXN],str2[MAXN]; struct AC{ int ch[MAXNODE][SIGMA_SIZE],f[MAXNODE],sz; LL val[MAXNODE]; void init(){ memset(ch[0],0,sizeof(ch[0])); val[0]=0; sz=1; } int idx(char c){ return c-'a'; } void insert(string s){ int u=0,len=s.size(); for(int i=0;i<len;i++){ int c=idx(s[i]); if(!ch[u][c]){ memset(ch[sz],0,sizeof(ch[sz])); val[sz]=0; ch[u][c]=sz++; } u=ch[u][c]; } val[u]++; } void get_fail(){ queue<int> q; f[0]=0; for(int c=0;c<SIGMA_SIZE;c++){ int u=ch[0][c]; if(u){ f[u]=0; q.push(u); } } while(!q.empty()){ int r=q.front(); q.pop(); for(int c=0;c<SIGMA_SIZE;c++){ int u=ch[r][c]; if(!u){ ch[r][c]=ch[f[r]][c]; continue; } q.push(u); f[u]=ch[f[r]][c]; val[u]+=val[f[u]]; } } } void find(string T){ int j=0; int len=T.size(); for(int i=0;i<len;i++){ int c=idx(T[i]); j=ch[j][c]; ans+=val[j]; } } }ac; int main(){ int t; s(t); while(t--){ ss(n, m); for(int i=1;i<=n;i++) cin>>str1[i]; ac.init(); for(int i=1;i<=m;i++){ cin>>str2[i]; ac.insert(str2[i]); } ac.get_fail(); for(int i=1;i<=n;i++){ ans=0; ac.find(str1[i]); printf("%I64d\n",ans); } } return 0; }
版权声明:本文为博主原创文章,未经博主允许不得转载。
原文:http://blog.csdn.net/bigsungod/article/details/47659289