KD树的作用就是:给你一个点集,然后对这个点集建立一颗KD树,然后可以在logn — 根号n,的范围内查询距离 一个给定点 最近的、第k近的,前k近的点。
建立KD树不仅仅只是两维,它可以多维。假设我们建树到第dep层,那么我们就以第dep % D (D是总维数)为基准,像普通的二叉树那样,选一个分裂点 m,然后 [ l , m-1] 的点的 第dep%D维都是小于 m 这个点,[ m + 1,r ]大于这个点,(用到了nth_element 这个函数)然后往下建树。(其实有另一种方法是按照方差最大的维度为基准,但是代码量++,而且跑的可能还没有这种交替维度来的快。。。。)
然后查询给定点就是:假如到达第dep层,这里的中间节点是 m ,然后我们先对 m 和给定点的距离更新一下答案,然后假如 第 dep % D 维 中,给定点 < m ,就查左子树,反之查右子树。然后假如查了一次子树之后,假设 给定点 第 dep % D 维 和 m 点 差是deta , 那么 如果ans <= deta * deta ,(这里ans记录距离的平方),那么就不用查另一颗子树了,反之要查。这个和分治法求平面最近点对是一样的,因为另一颗子数的距离肯定比 deta * deta 来得大。
所以,KD树其实就是个暴力,给定点先往和和它接近的平面走,更新完答案再按需求查另一个子树。
而对于第k大,只需要在之前的操作中加一个优先队列,每一次更新答案变成往优先队列里面插入,维护优先队列的大小就可。
例题:hdu4347(第k近):http://acm.hdu.edu.cn/showproblem.php?pid=4347
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int M = 7; 5 const int N = 5e4 + 9; 6 int n,D,cmp_d,K; 7 struct Point{ 8 ll x[M],dis; 9 int son[2]; 10 void print(){ 11 for(int i = 0;i<D;++i) printf("%lld%c",x[i],i == D-1 ? ‘\n‘ : ‘ ‘); 12 } 13 bool operator < (const Point& b)const{ 14 return dis < b.dis; 15 } 16 }tr[N],ans[N],Q; 17 priority_queue<Point> pq; 18 bool cmp(Point a,Point b){ 19 return a.x[cmp_d] < b.x[cmp_d]; 20 } 21 ll distance(Point a,Point b){ 22 ll res = 0; 23 for(int i = 0; i < D;++i) res += (a.x[i] - b.x[i]) * (a.x[i] - b.x[i]); 24 return res; 25 } 26 int build(int l,int r,int now_d){ 27 if( l > r ) return 0; 28 int m = (l+r)>>1; 29 cmp_d = now_d; 30 nth_element(tr+l,tr+m,tr+r+1,cmp); 31 tr[m].son[0] = build(l,m-1,(now_d+1)%D ); 32 tr[m].son[1] = build(m+1,r,(now_d+1)%D ); 33 return m; 34 } 35 void query(int pos,int now_d){ 36 if( !pos ) return; 37 ll cur_dis = distance( tr[pos],Q ); 38 tr[pos].dis = cur_dis; 39 ll deta = tr[pos].x[now_d] - Q.x[now_d]; 40 int which = (deta < 0); 41 query(tr[pos].son[which],(now_d + 1)%D ); 42 43 if( pq.size() < K ) pq.push(tr[pos]); 44 else{ 45 if(cur_dis < pq.top().dis ){ 46 pq.pop(); 47 pq.push(tr[pos]); 48 } 49 } 50 if( pq.size() < K || pq.top().dis > deta*deta ) query(tr[pos].son[which^1],(now_d+1)%D); 51 } 52 int main(){ 53 while(~scanf("%d%d",&n,&D)){ 54 for(int i = 1;i<=n;++i){ 55 tr[i].son[0] = tr[i].son[1] = 0; 56 for(int j = 0;j<D;++j) scanf("%lld",&tr[i].x[j]); 57 } 58 int root = build(1,n,0); 59 int m; scanf("%d",&m); 60 while(m--){ 61 for(int j = 0;j<D;++j) scanf("%lld",&Q.x[j]); 62 scanf("%d",&K); 63 query(root,0); 64 printf("the closest %d points are:\n",K); 65 for(int i = K;i>=1;--i) ans[i] = pq.top(),pq.pop(); 66 for(int i = 1;i<=K;++i) ans[i].print(); 67 } 68 } 69 70 }
例题:hdu2966(最近):http://acm.hdu.edu.cn/showproblem.php?pid=2966
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const ll inf = 1000000000000000000; 5 const int M = 7; 6 const int N = 1e5 + 9; 7 int n,D,cmp_d; 8 ll ans; 9 struct Point{ 10 ll x[M]; 11 int son[2]; 12 int id; 13 void print(){ 14 for(int i = 0;i<D;++i) printf("%lld%c",x[i],i == D-1 ? ‘\n‘ : ‘ ‘); 15 } 16 }tr[N],Q,query_p[N]; 17 bool cmp(Point a,Point b){ 18 return a.x[cmp_d] < b.x[cmp_d]; 19 } 20 ll distance(Point a,Point b){ 21 if( a.id == b.id ) return inf; 22 ll res = 0; 23 for(int i = 0; i < D;++i) res += (a.x[i] - b.x[i]) * (a.x[i] - b.x[i]); 24 return res; 25 } 26 int build(int l,int r,int now_d){ 27 if( l > r ) return 0; 28 int m = (l+r)>>1; 29 cmp_d = now_d; 30 nth_element(tr+l,tr+m,tr+r+1,cmp); 31 tr[m].son[0] = build(l,m-1,(now_d+1)%D ); 32 tr[m].son[1] = build(m+1,r,(now_d+1)%D ); 33 return m; 34 } 35 void query(int pos,int now_d){ 36 if( !pos ) return; 37 ll cur_dis = distance( tr[pos],Q ); 38 ans = min(cur_dis,ans); 39 ll deta = tr[pos].x[now_d] - Q.x[now_d]; 40 int which = (deta < 0); 41 query(tr[pos].son[which],(now_d + 1)%D ); 42 if(ans > deta*deta ) query(tr[pos].son[which^1],(now_d+1)%D); 43 } 44 int main(){ 45 int T; scanf("%d",&T); 46 D = 2; 47 while(T--){ 48 int n; scanf("%d",&n); 49 for(int i = 1;i<=n;++i){ 50 tr[i].son[0] = tr[i].son[1] = 0; 51 tr[i].id = i; 52 for(int j = 0;j<D;++j) scanf("%lld",&tr[i].x[j]); 53 query_p[i] = tr[i]; 54 } 55 int root = build(1,n,0); 56 for(int i = 1;i<=n;++i){ 57 Q = query_p[i]; 58 ans = 1000000000000000000; 59 query(root,0); 60 printf("%lld\n",ans); 61 } 62 } 63 }
原文:https://www.cnblogs.com/xiaobuxie/p/12266744.html