对于一个给定长度为N的字符串,求它的第K小子串是什么。并且给出一个t,t=0表示不同位置的相同子串算作一个,t=1表示不同位置的相同子串算作多个。
输出仅一行,为一个数字串,为第K小的子串。如果子串数目不足K个,则输出-1。
endpos[i]表示状态i的子串个数。
1、字串不可重复:endpos=1;
2、字串可重复:先将所有状态根据len排序,从后往前更新endpos,若状态i是复制出的节点,endpos[i]的值为以他为fail的状态的endpos的和,否则为和+1。
sum表示从当前状态出发能得到的子串数,所以是sum的孩子的和。即sum[i]=\(\sum\)$ sum[j]$(j=next[i][k],k=0...25)
求出两个数组后,通过dfs求出第k小的子串。如果当前状态的endpos大于等于k,终止。如果当前状态sum大于k,说明将在该状态的孩子中终止,输出当前状态字符,继续dfs;否则说明第k大子串不在这条路上,减去sum搜索下一条路径。
#include <bits/stdc++.h>
#define LL long long
#define P pair<int, int>
#define lowbit(x) (x & -x)
#define mem(a, b) memset(a, b, sizeof(a))
#define rep(i, a, n) for (int i = a; i <= n; ++i)
const int maxn = 500001;
#define mid ((l + r) >> 1)
#define lc rt<<1
#define rc rt<<1|1
using namespace std;
const LL mod = 1e9 + 7;
int T,k;
struct SAM{
int trans[maxn<<1][26], slink[maxn<<1], maxlen[maxn<<1];
// 用来求endpos
int indegree[maxn<<1], endpos[maxn<<1], rank[maxn<<1], ans[maxn<<1];
// 计算所有子串的和(0-9表示)
LL sum[maxn<<1];
int last, now, root;
inline void newnode (int v) {
maxlen[++now] = v;
mem(trans[now],0);
}
inline void extend(int c) {
newnode(maxlen[last] + 1);
int p = last, np = now;
// 更新trans
while (p && !trans[p][c]) {
trans[p][c] = np;
p = slink[p];
}
if (!p) slink[np] = root;
else {
int q = trans[p][c];
if (maxlen[p] + 1 != maxlen[q]) {
// 将q点拆出nq,使得maxlen[p] + 1 == maxlen[q]
newnode(maxlen[p] + 1);
int nq = now;
memcpy(trans[nq], trans[q], sizeof(trans[q]));
slink[nq] = slink[q];
slink[q] = slink[np] = nq;
while (p && trans[p][c] == q) {
trans[p][c] = nq;
p = slink[p];
}
}else slink[np] = q;
}
last = np;
// 初始状态为可接受状态
endpos[np] = 1;
printf("%d\n",np);
}
inline void init()
{
root = last = now = 1;
slink[root]=0;
mem(trans[root],0);
}
inline void getEndpos() {
for (int i = 1; i <= now; ++i) indegree[ maxlen[i] ]++; // 统计相同度数的节点的个数
for (int i = 1; i <= now; ++i) indegree[i] += indegree[i-1]; // 统计度数小于等于 i 的节点的总数
for (int i = 1; i <= now; ++i) rank[ indegree[ maxlen[i] ]-- ] = i; // 为每个节点编号,节点度数越大编号越靠后
// 从下往上按照slik更新
for (int i = now; i >= 1; --i) {
int x = rank[i];
if(T==1)
endpos[slink[x]] += endpos[x];
else endpos[x]=1;
}
endpos[1]=0;
for(int i=now ; i>=1 ; i--){
int x = rank[i];
sum[x]=endpos[x];
for(int j=0;j<26;j++)///后面可以接的字符
sum[x]+=sum[trans[x][j]];
}
}
void dfs(int x,int K)
{
if(K<=endpos[x]) return ;
K-=endpos[x];
for(int i=0 ; i<26 ; i++){
int p=trans[x][i];
if(p){
if(K<=sum[p]){
printf("%c",i+‘a‘);
dfs(p,K);
return ;
}
K-=sum[p];
}
}
}
}sam;
char s[maxn];
int main()
{
scanf("%s",s+1);
scanf("%d%d",&T,&k);
sam.init();
int len=strlen(s+1);
for(int i=1;i<=len;i++) printf("@%d ",i),sam.extend(s[i]-‘a‘);
sam.getEndpos();
sam.dfs(sam.root,k);
}
原文:https://www.cnblogs.com/qjy73/p/12622038.html