题目链接:https://www.luogu.com.cn/problem/P3449
Johnny 喜欢玩文字游戏。
他写下了 \(n\) 个回文串,随后将这些串两两组合,合并成一个新串。容易看出,一共会有 \(n^2\) 个新串。
两个串组合时顺序是任意的,即 a
和 b
可以组合成 ab
和 ba
,另外自己和自己组合也是允许的。
现在他想知道这些新串中有多少个回文串,你能帮帮他吗?
不难发现,一个长度较短的串\(s1\)与长度较长的串\(s2\)拼接起来能是一个回文串,必须满足以下两点:
那么我们把所有的串用\(vector\)存起来,然后全部插入一棵\(Trie\)里,再枚举每一个串,将它倒过来在\(Trie\)里寻找。
如果找到某一个时刻有一个串与这个倒过来的串相匹配,那么说明这两个串满足了条件\((1)\)。如果再满足剩余部分是一个回文串,那么这两个串就对答案有贡献。
判断是否是回文串用\(hash\)就可以了。正序做一遍\(hash\),倒序做一遍\(hash\),然后判断一个区间是否回文就把这个区间的前半部分的正序\(hash\)值与后半部分的倒序\(hash\)值相比较即可。
由于这道题给出的串全部回文,所以每一个可行方案倒过来也是一种可行方案,所以\(ans\)要乘\(2\),又因为我们没有计算自己与自己匹配,所以最终答案是\(2ans+n\)。
时间复杂度\(O(n)\)
\(update:\)其实可以不用倒过来的。。。因为给出的串都是回文串,倒过来的前缀其实还是原来的前缀。。。当时傻掉了。
#include <string>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int N=2000010;
const ull base=131;
int n,maxn,len[N];
ll ans;
char Ch;
ull hash[N][3],power[N];
vector<string> ch;
string s;
struct Trie
{
int tot,trie[N][27];
bool end[N];
void find(int i,int j)
{
int p=1;
for (;j>=0;j--)
{
if (end[p]==1)
{
int mid=(1+j+1)>>1,f=j&1;
if (hash[mid][1]==hash[mid+f][2]-hash[j+2][2]*power[j+2-mid-f]) ans++;
}
if (trie[p][ch[i][j]-'a'+1]) p=trie[p][ch[i][j]-'a'+1];
else return;
}
}
void update(int i)
{
int p=1;
for (register int j=0;j<len[i];j++)
{
if (!trie[p][ch[i][j]-'a'+1])
trie[p][ch[i][j]-'a'+1]=++tot;
p=trie[p][ch[i][j]-'a'+1];
}
end[p]=1;
}
}trie;
int main()
{
scanf("%d",&n);
ch.push_back("WYC AK IOI");
for (register int i=1;i<=n;i++)
{
scanf("%d",&len[i]);
s="";
for (register int j=1;j<=len[i];j++)
{
while (Ch=getchar()) if (Ch>='a' && Ch<='z') break;
s+=Ch;
}
ch.push_back(s);
maxn=max(maxn,len[i]);
}
power[0]=1;
for (register int i=1;i<=maxn;i++)
power[i]=power[i-1]*base;
trie.tot=1;
for (register int i=1;i<=n;i++)
trie.update(i);
for (int i=1;i<=n;i++)
{
hash[0][1]=hash[len[i]+1][2]=0;
for (register int j=1;j<=len[i];j++)
hash[j][1]=hash[j-1][1]*base+ch[i][j-1]-'a'+1;
for (register int j=len[i];j>=1;j--)
hash[j][2]=hash[j+1][2]*base+ch[i][j-1]-'a'+1;
trie.find(i,len[i]-1);
}
printf("%lld\n",ans*2+n);
return 0;
}
原文:https://www.cnblogs.com/stoorz/p/12231879.html