0%

KD-Tree

使用KNN算法,需要找到样本点周围最近的N个点。最简单的方法是求出所有距离,然后找出前K大。然而点的数量巨大时,计算量会非常大。为了优化KNN算法,可以采用kd-tree(k维树)。它可以在各个维度的空间内对较大量的点进行检索。

kd树是二叉树,构造过程简单来说就是不断用垂直于坐标轴的超平面将k维空间切分,构成许多超矩形区域。每个节点都对应一个超平面,这个超平面过节点中存储的点并且将k维空间分成两部分。

平衡kd树的构建是递归的,每次选择一个区间内的中位数大(后面解释)的点为根节点,然后中位数点左侧和右侧区间递归这个构建过程。直到区间大小为一。

中位数大的点就是指如果按照一定规则给点排序,排序后下标是中位数的那个点(中间的那个点)。构建kd树时给某区间点的排序规则是:位于第x层的节点比较第(x%k+1)维坐标的大小。比如:如果是构建整个树的根节点,则所有点都参与排序,比较的是它们的第一维的大小。位于中间的点就是根节点。

查询前n近的算法简单来说就是:用一个大小位n的大顶堆保存前n近的点。先从根节点出发找到给定点所在空间对应的kd树的叶子节点,并且在找这个点的时候把路途经过的点加入堆里。如果堆满了,而且新经过的点离待查询点更近一些,那就把堆顶的点去掉,并添加进这个点(注意新的点不一定还在堆顶)。这时候,堆中所有点都包含在以给定点为圆(球)心,以给定点到堆顶点为半径的圆(球)内。之后从叶子节点一层层返回,每到一层的一个父节点,就看另一个叶子节点所在的区域是不是与这个圆相交。如果相交说明这一侧可能有更近的点,那么就进入这一侧搜寻更近的点。

查询过程也可以用递归实现。判断与圆(球)相交的方法:圆(球)心到超平面的距离小于半径。

在网上看到的资料大多没有简单的代码,而且缺少注释。许多代码是复制一个开源C++库的,作为学习来说源代码的结构有点复杂,不太适合学习。维基上有更详细的解释和许多有用的学习资料连接,而且有python版的实现。

为了自己实现一个简单的kd树练练手,我从网上搜了一道杭电OJ上的题。对于KNN算法来说很实用的题目,要求就是给一些点,找离目标点前M近的点。只涉及树的建立和查询。

这道OJ题的代码:

#include <bits/stdc++.h>

using namespace std;

#define MAX_DIM 5
#define DIS(X) ((X)*(X))

int n_dim;    //当前所比较的维度,分割面分割的维度

struct Point{
    int coord[MAX_DIM]; //坐标
    Point *lft, *rgt;   //树的左右节点指针

    Point(int k){
        lft=rgt=NULL;
        for(int i=0; i<k; i++)
            scanf("%d", &coord[i]);
    }

    Point(){
        lft=rgt=NULL;
    }

    inline bool operator<(const Point &b)const{
        return coord[n_dim]<b.coord[n_dim];
    }
};

struct kdTree{
    vector<Point> allp; //全体点
    priority_queue<pair<double, Point*> > *resultq; //查到的点
    int dim;    //空间的维度
    Point *root;    //树根指针

    kdTree(int n, int k){
        resultq=NULL;
        root=NULL;
        dim=k;
        for(int i=0; i<n; i++){
            allp.push_back(Point(k));
        }
        build(0, allp.size()-1, 0, root);
    }

    void query(int m, Point &p){
        Point res[20];
        resultq=new priority_queue<pair<double, Point*> >;
        queryInner(p, m, 0, root);
        //给出查询结果
        printf("the closest %d points are:\n", m);
        for(int n=0; !resultq->empty(); n++){
            res[n]=*(resultq->top().second);
            resultq->pop();
        }
        for(int n=m-1; n>=0;n--){
            for(int i=0; i<dim; i++)
                printf("%d%c", res[n].coord[i], i==dim-1?'\n':' ');
        }
        delete resultq;
        resultq=NULL;
    }

    void build(int l, int r, int dep, Point* &rt){
        if(l>r)return;
        int mid=(l+r)>>1;
        n_dim=dep%dim;  //存储分割面分割的是那个维度
        nth_element(allp.begin()+l, allp.begin()+mid, allp.begin()+r+1);

        rt=&allp[mid];  //把空间的点接到树上
        build(l, mid-1, dep+1, rt->lft);
        build(mid+1, r, dep+1, rt->rgt);
    }

    void queryInner(Point &p, int m, int dep, Point *rt){
        if(rt==NULL)return;
        pair<double, Point*> tmp=make_pair(0.0, rt); //计算到被查点的距离,准备构建结果队列
        for(int i=0; i<dim; i++)
            tmp.first+=DIS(rt->coord[i]-p.coord[i]);

        int now_dim=dep%dim;
        bool flg=false;
        Point *go=rt->lft, *go_another=rt->rgt;

        if(p.coord[now_dim]>=rt->coord[now_dim])
            swap(go, go_another);   //go代表被查点所在的一侧
        if(go)
            queryInner(p, m, dep+1, go);
        if((int)resultq->size()<m){
            resultq->push(tmp);
            flg=true;
            //查到的结果不够,一定向另一侧递归
        }else{
            if(tmp.first<resultq->top().first){
                resultq->pop();
                resultq->push(tmp);
            }//发现了更近的点
            /*待查询点与最远点形成的超球
                与分割空间的超平面相交,向不是所在的一侧递归
            */
            if(DIS(p.coord[now_dim] - rt->coord[now_dim]) < resultq->top().first)
                flg=true;
        }
        if(go_another && flg)
            queryInner(p, m, dep+1, go_another);
    }
    
};


int main(){
    //freopen("1.txt", "r" ,stdin);
    int n, k;
    while(scanf("%d%d", &n, &k)!=EOF){
        kdTree *tree = new kdTree(n, k);
        int t;
        scanf("%d", &t);
        while(t--){
            Point tmp=Point(k);
            int m;
            scanf("%d", &m);
            tree->query(m, tmp);
        }
        delete tree;
    }
    return 0;
}

过段时间会再尝试一下用python写一个kd树。了解到kd树也是从《统计学习方法》上看到的。但是sklearn库还提供了ball-tree。据说比kd树还好。sklearn上现成的算法确实很高效,要远远比自己写的算法快,而且还提供了不少额外功能。但我估计可能是它底层有C/C++优化的原因。

网上搜OJ题解的时候看到所有的人都是开了四倍最大点数的定长数组写的,代码非常短。但其实根本没有必要(也许做比赛有必要吧,但我只求完成功能)。实际上,上面这份代码无论是消耗的内存空间还是执行时间都比开数组的方法小。

pic1