给你一颗点上有字符的树,问一个给定的字符串是否是这棵树上的两点的路径。
树分治的思想就是每次找重心,重心下的子问题分解去做,然后就是合并了。合并的时候用一个总的set<pair<len,hash>> 去存从根节点往下走的长度以及对应的hash值,判的时候只需要看下是否已经存在 m-len,以及对应的前缀(或者后缀)的哈希值,然后再加进来。
两个优化的点是,1是递归解子问题的时候如果子树规模小于要给的字符串可以不用递归下去。2是存pair的时候只需要存前缀的pair以及后缀的pair,其它的都不用存。
加了这些优化之后我的程度勉强能在4000ms内通过。看来我树分治的写法还是太慢了。。
#pragma warning(disable:4996) #include <iostream> #include <cstdio> #include <cstring> #include <vector> #include <cmath> #include <algorithm> #include <string> #include <cstdlib> #include <ctime> #include <map> #include <set> using namespace std; #define maxn 10010 #define ll long long #define MP make_pair vector<int> G[maxn]; bool centroid[maxn]; int ssize[maxn]; char val[maxn]; char tar[maxn]; set<pair<ll, ll> > sta; set<pair<ll, ll> >::iterator it; int n, m; ll mod_num; ll mod; ll xpow[maxn]; ll pre[maxn]; ll post[maxn]; ll xorr(ll x, ll y) { return (x*mod_num%mod + y) % mod; } int compute_ssize(int v, int p) { int c = 1; for (int i = 0; i<G[v].size(); ++i){ int w = G[v][i]; if (w == p || centroid[w]) continue; c += compute_ssize(G[v][i], v); } ssize[v] = c; return c; } pair<int, int> search_centroid(int v, int p, int t) { pair<int, int> res = make_pair(INT_MAX, -1); int s = 1, m = 0; for (int i = 0; i < G[v].size(); ++i){ int w = G[v][i]; if (w == p || centroid[w]) continue; res = min(res, search_centroid(w, v, t)); m = max(m, ssize[w]); s += ssize[w]; } m = max(m, t - s); res = min(res, make_pair(m, v)); return res; } void enumerate_mul(int v, int p, pair<ll, ll> d, set<pair<ll, ll> > &ds) { if (!ds.count(d)) ds.insert(d); for (int i = 0; i < G[v].size(); ++i){ int w = G[v][i]; if (w == p || centroid[w]) continue; enumerate_mul(w, v, MP(d.first + 1, xorr(d.second, val[w])), ds); } } bool judge(pair<ll, ll> x, const set<pair<ll, ll> >& tds) { if (pre[x.first] == x.second){ return tds.count(MP(m - x.first, post[x.first + 1])); } if (post[m - x.first + 1] == x.second){ return tds.count(MP(m - x.first, pre[m - x.first])); } return false; } bool solve(int v) { compute_ssize(v, -1); int s = search_centroid(v, -1, ssize[v]).second; centroid[s] = true; for (int i = 0; i<G[s].size(); ++i){ if (centroid[G[s][i]]) continue; if (ssize[G[s][i]] < m){ continue; } if (solve(G[s][i])) { return true; } } sta.clear(); sta.insert(MP(1, val[s])); if (m == 1 && val[s] == tar[1]){ return true; } set<pair<ll, ll> > tds; for (int i = 0; i<G[s].size(); ++i){ if (centroid[G[s][i]]) continue; tds.clear(); enumerate_mul(G[s][i], s, MP(1, val[G[s][i]]), tds); it = tds.begin(); while (it != tds.end()){ if (judge(*it, sta)){ return true; } ++it; } it = tds.begin(); while (it != tds.end()){ ll one = (*it).first; ll two = (*it).second; pair<ll, ll> vv; if (one > m){ ++it; continue; } if (pre[one] != two && post[m - one + 1] != two){ ++it; continue; } if (tar[one + 1] != val[s] && tar[m - one] != val[s]){ ++it; continue; } vv.first = one + 1; vv.second = (xpow[one] * val[s] % mod + two) % mod; if (!sta.count(vv)){ sta.insert(vv); } ++it; } } centroid[s] = false; return false; } ll haha[3] = { 37, 23, 53 }; ll tata[2] = { 1000000007, 1000010009 }; int main() { int T; cin >> T; int ca = 0; while (T--) { cin >> n; for (int i = 0; i <= n; ++i) G[i].clear(); int ui, vi; for (int i = 0; i < n - 1; ++i){ scanf("%d%d", &ui, &vi); G[ui].push_back(vi); G[vi].push_back(ui); } memset(centroid, 0, sizeof(centroid)); scanf("%s", val + 1); scanf("%s", tar + 1); mod_num = haha[rand() % 3]; mod = tata[rand() % 2]; xpow[0] = 1; for (int i = 1; i <= n; ++i){ xpow[i] = xpow[i - 1] * mod_num%mod; } pre[0] = 0; m = strlen(tar + 1); for (int i = 1; i <= m; ++i){ pre[i] = (xpow[i - 1] * tar[i] % mod + pre[i - 1]) % mod; } post[m + 1] = 0; for (int i = m; i >= 1; --i){ post[i] = (tar[i] * xpow[m - i] % mod + post[i + 1]) % mod; } bool flag = solve(1); if (flag){ printf("Case #%d: Find\n", ++ca); } else{ printf("Case #%d: Impossible\n", ++ca); } } return 0; }
原文:http://www.cnblogs.com/chanme/p/4841239.html