题目链接: http://poj.org/problem?id=2763
题意: 第一行输入 n, q, s 分别为树的顶点个数, 询问/修改个数, 初始位置. 接下来 n - 1 行形如 x, y, w 的输入为点 x, y 之间连边且边权为 w.
接下来 q 行输入, 若输入形式为 1 x y 则为将点 x 的权值修改为 y , 若输入形式为 0 x 则询问 s 到 x 的最短距离为多少. 上一组的 x 为下一组的 s.
思路: 若去掉修改边权部分, 则为一个 lca 模板题. 对于修改边权, 直接暴力向下修改 dis 数组可能会 tle . 可以将树映射到线段树上, 那样修改时间可降为为 log(n).
具体操作为, dfs 出每点的入点时间戳 l [MAXN] 和出点时间戳 r [MAXN] , 并记录 dfs 顺序 sol [MAXN] .
将 dis [sol [MAXN] ] 放入 sum 数组建线段树.
那么 x 到根结点的距离为: query(l[x], l[x], 1, n, 1) .
将 x 点的权值由 w 修改到 w‘ 操作为: updata(l[x], r[x], w - w‘, 1, n, 1) .
代码:
1 #include <iostream> 2 #include <stdio.h> 3 #include <math.h> 4 #include <string.h> 5 #define lson l, mid, rt << 1 6 #define rson mid + 1, r, rt << 1 | 1 7 using namespace std; 8 9 const int MAXN = 3e5 + 10; 10 struct node{ 11 int v, w, next; 12 node(){}; 13 node(int V, int W, int NEXT) : v(V), w(W), next(NEXT){}; 14 }edge[MAXN << 1]; 15 16 int dp[MAXN << 1][20]; 17 int a[MAXN], b[MAXN], w[MAXN]; 18 int sum[MAXN << 2], add[MAXN << 2]; 19 int l[MAXN], r[MAXN], sol[MAXN], cnt; 20 int head[MAXN], dis[MAXN], dep[MAXN], ip, indx; 21 int first[MAXN], ver[MAXN << 1], deep[MAXN << 1]; 22 23 void init(void){ 24 memset(head, -1, sizeof(head)); 25 ip = 0; 26 cnt = 0; 27 indx = 0; 28 } 29 30 void addedge(int u, int v, int w){ 31 edge[ip] = node(v, w, head[u]); 32 head[u] = ip++; 33 } 34 35 void dfs(int u, int pre, int h){ 36 dep[u] = h; 37 ver[++indx] = u; 38 deep[indx] = h; 39 first[u] = indx; 40 for(int i = head[u]; i != -1; i = edge[i].next){ 41 int v = edge[i].v; 42 if(v == pre) continue; 43 dis[v] = dis[u] + edge[i].w; 44 dfs(v, u, h + 1); 45 ver[++indx] = u; 46 deep[indx] = h; 47 } 48 } 49 50 51 void ST(int n){ 52 for(int i = 1; i <= n; i++){ 53 dp[i][0] = i; 54 } 55 for(int j = 1; (1 << j) <= n; j++){ 56 for(int i = 1; i + (1 << j) - 1 <= n; i++){ 57 int x = dp[i][j - 1], y = dp[i + (1 << (j - 1))][j - 1]; 58 dp[i][j] = deep[x] < deep[y] ? x : y; 59 } 60 } 61 } 62 63 int RMQ(int l, int r){ 64 int len = log2(r - l + 1); 65 int x = dp[l][len], y = dp[r - (1 << len) + 1][len]; 66 return deep[x] < deep[y] ? x : y; 67 } 68 69 int LCA(int x, int y){ 70 int l = first[x], r = first[y]; 71 if(l > r) swap(l, r); 72 int pos = RMQ(l, r); 73 return ver[pos]; 74 } 75 76 void dfs2(int x, int pre){ 77 sol[++cnt] = x; 78 l[x] = cnt; 79 for(int i = head[x]; i != -1; i = edge[i].next){ 80 int v = edge[i].v; 81 if(v != pre) dfs2(v, x); 82 } 83 r[x] = cnt; 84 } 85 86 void push_up(int rt){ 87 sum[rt] = sum[rt << 1] + sum[rt << 1 | 1]; 88 } 89 90 void push_down(int rt, int m){ 91 if(add[rt]){ 92 add[rt << 1] += add[rt]; 93 add[rt << 1 | 1] += add[rt]; 94 sum[rt << 1] += (m - (m >> 1)) * add[rt]; 95 sum[rt << 1 | 1] += (m >> 1) * add[rt]; 96 add[rt] = 0; 97 } 98 } 99 100 void build(int l, int r, int rt){ 101 add[rt] = 0; 102 if(l == r){ 103 sum[rt] = dis[sol[l]]; 104 return; 105 } 106 int mid = (l + r) >> 1; 107 build(lson); 108 build(rson); 109 push_up(rt); 110 } 111 112 void updata(int L, int R, int key, int l, int r, int rt){ 113 if(L <= l && R >= r){ 114 sum[rt] += (r - l + 1) * key; 115 add[rt] += key; 116 return; 117 } 118 push_down(rt, r - l + 1); 119 int mid = (l + r) >> 1; 120 if(L <= mid) updata(L, R, key, lson); 121 if(R > mid) updata(L, R, key, rson); 122 push_up(rt); 123 } 124 125 int query(int L, int R, int l, int r, int rt){ 126 if(L <= l && R >= r) return sum[rt]; 127 push_down(rt, r - l + 1); 128 int mid = (l + r) >> 1; 129 int ans = 0; 130 if(L <= mid) ans += query(L, R, lson); 131 if(R > mid) ans += query(L, R, rson); 132 return ans; 133 } 134 135 int main(void){ 136 int n, q, s, x, y, z, op; 137 while(~scanf("%d%d%d", &n, &q, &s)){ 138 init(); 139 for(int i = 1; i < n; i++){ 140 scanf("%d%d%d", &x, &y, &z); 141 addedge(x, y, z); 142 addedge(y, x, z); 143 a[i] = x; 144 b[i] = y; 145 w[i] = z; 146 } 147 dis[1] = 0; 148 dfs(1, -1, 1); 149 ST(indx); 150 dfs2(1, -1); 151 build(1, cnt, 1); 152 while(q--){ 153 scanf("%d", &op); 154 if(!op){ 155 scanf("%d", &y); 156 int lca = LCA(s, y); 157 int sol1 = query(l[s], l[s], 1, cnt, 1); 158 int sol2 = query(l[y], l[y], 1, cnt, 1); 159 int sol3 = query(l[lca], l[lca], 1, cnt, 1); 160 printf("%d\n", sol1 + sol2 - 2 * sol3); 161 s = y; 162 }else{ 163 scanf("%d%d", &x, &y); 164 int u = a[x], v = b[x], add = y - w[x]; 165 int cnt2 = dep[u] > dep[v] ? u : v; 166 updata(l[cnt2], r[cnt2], add, 1, cnt, 1); 167 w[x] = y;//注意记录修改后的权值 168 } 169 } 170 } 171 return 0; 172 }
原文:http://www.cnblogs.com/geloutingyu/p/7224891.html