先一道一道题慢慢补上,
1009.题意,一棵N(N<=50000)个节点的树,每个节点上有一个字母值,给定一个串S0(|S0| <=30),q个询问,(q<=50000),每次询问经过两个点u,v之间的路径上的字母构成字符串S,问S0在S中作为序列出现了多少次。
分析:对于每次询问需要知道其LCA,设w = LCA(u, v),可以用tarjan处理出来,或者倍增法也行。然后w会将S0分成两部分,记做S1,S2,然后分别考虑S1在u->w的路径出现次数,以及S2在v->w出现的次数。
S1(x) = S0[1....x],1<=x<=|S0|.
以S1为例,需要预处理出来u到根节点的路径上对应S0[i,j]序列出现的次数,设为dp1[u][i][j],然后S1(x)在u->w出现的次数记为t1[x],那么
t1[x] = dp1[u][1][x] - sum{t1[a-1]*dp1[fa[w]][a][x]} (1<=a<=x)
而dp1[u][1][x] 可以在tarjan求LCA时候预处理出来,转移dp1[u][i][j] = dp1[fa[u]][i][j] + (nd[u] == S0[i])*dp1[fa[u]][i+1][j];
对于S2也是可以同样考虑的。
注意:占用内存很大,需要用16位节省内存,dfs的话还需要扩栈,实现时发现不能随便用unsigned short,因为变成负数的话会相当于mod 2^16这个是会导致WA的。
代码:
1 #pragma comment(linker, "/STACK:16777216") 2 #include <cstdio> 3 #include <iostream> 4 #include <cstring> 5 #include <string> 6 #include <cstdlib> 7 #include <algorithm> 8 #include <vector> 9 #include <queue> 10 #include <map> 11 #include <set> 12 #define in freopen("solve_in.txt", "r", stdin); 13 #define Rep(i, base, n) for(int i = (base); i < n; i++) 14 #define REP(i, n) for(int i = 0; i < (n); i++) 15 #define VREP(i, n, base) for(int i = (n); i >= (base); i--) 16 #define SET(a, n) memset(a, (n), sizeof(a)); 17 #define pb push_back 18 #define mp make_pair 19 20 using namespace std; 21 typedef vector<unsigned short> VI; 22 typedef pair<unsigned short, unsigned short> PII; 23 typedef vector<PII> VII; 24 typedef long long LL; 25 26 const int maxn = 50000 + 10; 27 const int maxm = 33; 28 const int M = 10007 ; 29 30 short dp1[maxn][maxm][maxm], dp2[maxn][maxm][maxm]; 31 bool vis[maxn]; 32 int len; 33 char s[maxm], nd[maxn]; 34 unsigned short pa[maxn]; 35 36 VI g[maxn]; 37 unsigned short lca[maxn]; 38 VII que[maxn]; 39 VII qq; 40 41 unsigned short findset(unsigned short x) { 42 return x == pa[x] ? x : pa[x] = findset(pa[x]); 43 } 44 45 void init(int n) { 46 qq.clear(); 47 REP(i, n+1) { 48 que[i].clear(); 49 g[i].clear(); 50 vis[i] = false; 51 REP(j, maxm) REP(k, maxm) { 52 dp1[i][j][k] = 0, dp2[i][j][k] = 0; 53 if(j > k) 54 dp1[i][j][k] = 1; 55 if(j < k) 56 dp2[i][j][k] = 1; 57 } 58 } 59 } 60 int n, m; 61 short t1[maxm], t2[maxm]; 62 void tarjan(int u, int fa) { 63 pa[u] = u; 64 if(!fa) { 65 Rep(i, 1, len+1) 66 if(s[i] == nd[u]) { 67 dp1[u][i][i] = dp2[u][i][i] = 1; 68 } 69 } else { 70 int v = u; 71 u = fa; 72 Rep(i, 1, len+1) Rep(j, 1, len+1) { 73 if(i >= j) { 74 dp2[v][i][j] = (dp2[v][i][j] + dp2[u][i][j] + (nd[v] == s[i])*(dp2[u][i-1][j]))%M; 75 } 76 if(dp2[v][i][j] < 0) 77 dp2[v][i][j] += M; 78 if(i <= j) { 79 dp1[v][i][j] = (dp1[v][i][j] + dp1[u][i][j] + (nd[v] == s[i])*(dp1[u][i+1][j]))%M; 80 } 81 if(dp1[v][i][j] < 0) 82 dp1[v][i][j] += M; 83 } 84 u =v; 85 } 86 vis[u] = true; 87 REP(i, g[u].size()) { 88 int v = g[u][i]; 89 if(v == fa) 90 continue; 91 tarjan(v, u); 92 pa[v] = u; 93 } 94 REP(i, que[u].size()) { 95 int v = que[u][i].first; 96 int id = que[u][i].second; 97 if(!vis[v]) 98 continue; 99 lca[id] = findset(v); 100 } 101 } 102 void solve() { 103 REP(ii, m) { 104 int w = lca[ii]; 105 int u = qq[ii].first; 106 int v = qq[ii].second; 107 if(u == v) { 108 printf("%d\n", len == 1 && s[1] == nd[u]); 109 continue; 110 } 111 SET(t1, 0); 112 SET(t2, 0); 113 t1[0] = t2[len+1] = 1; 114 if(w != u) { 115 Rep(i, 1, len+1) { 116 int tmp = 0; 117 Rep(x, 1, i+1) { 118 tmp = (tmp + ((LL)t1[x-1]*dp1[w][x][i])%M)%M; 119 } 120 t1[i] = (dp1[u][1][i]-tmp)%M; 121 if(t1[i] < 0) 122 t1[i] += M; 123 } 124 } 125 if(w != v) { 126 VREP(i, len, 1) { 127 int tmp = 0; 128 VREP(x, len, i) { 129 tmp = (tmp + ((LL)t2[x+1]*dp2[w][x][i])%M)%M; 130 } 131 t2[i] = (dp2[v][len][i] - tmp)%M; 132 if(t2[i] < 0) 133 t2[i] += M; 134 } 135 } 136 int ans = 0; 137 REP(i, len+1) { 138 if(s[i] == nd[w]) 139 ans = (ans + ((LL)t1[i-1]*t2[i+1])%M)%M; 140 ans = (ans + ((LL)t1[i]*t2[i+1])%M)%M; 141 if(ans < 0) 142 ans += M; 143 } 144 if(ans < 0) 145 ans += M; 146 printf("%d\n", ans); 147 } 148 } 149 int main() { 150 151 int T; 152 for(int t = scanf("%d", &T); t <= T; t++) { 153 scanf("%d%d", &n, &m); 154 init(n); 155 REP(i, n-1) { 156 int u, v; 157 scanf("%d%d", &u, &v); 158 g[u].pb(v); 159 g[v].pb(u); 160 } 161 s[0] = ‘\0‘; 162 scanf("%s%s", nd+1, s+1); 163 len = strlen(s+1); 164 REP(i, m) { 165 int u, v; 166 scanf("%d%d", &u, &v); 167 qq.pb(mp(u, v)); 168 que[u].pb(mp(v, i)); 169 que[v].pb(mp(u, i)); 170 } 171 tarjan(1, 0); 172 solve(); 173 } 174 return 0; 175 }
原文:http://www.cnblogs.com/rootial/p/3899109.html