数据结构和算法——kd树
一、K-近邻算法
K-近邻算法是一种典型的无参监督学习算法,对于一个监督学习任务来说,其mm个训练样本为:
{(X(1),y(1)),(X(2),y(2)),⋯,(X(m),y(m))}
left { left ( X^{left ( 1 right )},y^{left ( 1 right )} right ),left ( X^{left ( 2 right )},y^{left ( 2 right )} right ),cdots ,left ( X^{left ( m right )},y^{left ( m right )} right ) right }
在K-近邻算法中,无需利用训练样本学习出统一的模型,对于一个新的样本,如XX,通过比较样本XX与mm个训练样本的相似度,选择出kk个最相似的样本,并以这kk个样本的标签作为样本XX的标签。
在如上的描述中,样本XX需要分别与mm个训练样本计算相似度,通常,使用的相似度的计算方法为欧式距离,即对于样本Xi={xi,1,xi,2,⋯,xi,n}X_i=left { x_{i,1},x_{i,2},cdots ,x_{i,n} right }和样本Xj={xj,1,xj,2,⋯,xj,n}X_j=left { x_{j,1},x_{j,2},cdots ,x_{j,n} right },其两者之间的相似度为:
S=∑t=1n(xi,t−xj,t)2−−−−−−−−−−−−−√
S=sqrt{sum_{t=1}^{n}left ( x_{i,t}-x_{j,t} right )^2}
对于K-近邻算法的具体过程,可以参见博文简单易学的机器学习算法——K-近邻算法。
在K-近邻算法的计算过程中,通过暴力的对每一对样本计算其相似度是非常好费时间的,那么是否存在一种方法,能够加快计算的速度?kd树便是其中的一种方法。
二、kd树
kd树是一种对kk维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,且kd树是一种二叉树,表示对kk维空间的一个划分。
1、二叉排序树
在数据结构中,二叉排序树又称二叉查找树或者二叉搜索树。其定义为:二叉排序树,或者是一棵空树,或者是具有下列性质的二叉树:
- 若它的左子树不空,则左子树上所有结点的值均小于它的根结点的值;
- 若它的右子树不空,则右子树上所有结点的值均大于它的根结点的值;
- 它的左、右子树也分别为二叉排序树。
一个典型的二叉排序树的例子如下图所示:
在二叉排序树中,若以中序遍历,则得到的是按照值大小排序的结果,即1->3->4->6->7->8->10->13->14。
如果需要检索7,则从根结点开始:
- 7<87<8->左子树
- 7>37>3->右子树
- 7>67>6->右子树
- 7=77=7->查找结束
但是,对于二叉排序树的建立,若构建二叉排序树的顺序为基本有序时,如按照1->3->4->6->7->8->10->13->14构建二叉排序树,会得到如下的结果:
这样的话,检索效率会下降,为了避免这样的情况的出现,会对二叉树设置一些条件,如平衡二叉树。对于二叉排序树的更多内容,可以参见数据结构和算法——二叉排序树。
2、kd树的概念
kd树与二叉排序树的基本思想类似,与二叉排序树不同的是,在kd树中,每一个节点表示的是一个样本,通过选择样本中的某一维特征,将样本划分到不同的节点中,如对于样本{(7,2),(5,4),(9,6)}left { left ( 7,2 right ),left ( 5,4 right ),left ( 9,6 right ) right }, 考虑数据的第一维,首先,根节点为{(7,2)}left { left ( 7,2 right )right },由于样本{(5,4)}left { left ( 5,4 right )right }的第一维55小于77,因此,样本{(5,4)}left { left ( 5,4 right )right }在根节点的左子树上,同理,样本{(9,6)}left { left ( 9,6 right )right }在根节点的右子树上。通过第一维可以构建如下的二叉树模型:
在kd树的基本操作中,主要包括kd树的建立和kd树的检索两个部分。
3、kd树的建立
构造kd树相当于不断地用垂直于坐标轴的超平面将kk维空间切分成一系列的kk维超矩阵区域。选择划分节点的方法主要有两种:
- 顺序选择,即按照数据的顺序依次在kd树中插入节点;
- 选择待划分维数的中位数为划分的节点。在kd树的构建过程中,为了防止出现只有左子树或者只有右子树的情况出现,通常对于每一个节点,选择样本中的中位数作为切分点。这样构建出来的kd树时平衡的。
在李航的《统计机器学习》P41中有提到:平衡的kd树搜索时的效率未必是最优的。
在构建kd树的过程中,也可以根据插入数据的顺序构建kd树,以二维数据集为例,其数据的顺序依次为:
{(3,6),(7,5),(3,1),(6,2),(9,1),(2,7)}
left { left ( 3,6 right ),left ( 7,5 right ),left ( 3,1 right ),left ( 6,2 right ),left ( 9,1 right ),left ( 2,7 right ) right }
对于如上的二维数据集,构建kd树:
- 选择一维最为切分的维度,如选择第00维,第一个数为(3,6)left ( 3,6 right ),其第00维的值为33,以(3,6)left ( 3,6 right )作为kd树的根结点,若第00维的值大于33为右子树,否则插入到左子树中;
- 对后续的节点依次判断,如(7,5)left ( 7,5 right ),选择第00维,其值为77,大于33,插入到根结点的右子树中,设置其维数为除了第00维以外的任一维。。。
按照如上的过程,我们划分出来的kd树如下图所示:
此时,将样本按照特征空间划分如下图所示:
由以上的计算过程可以看出对于树中节点,需要有数据项,当前节点的比较维度,指向左子树的指针和指向右子树的指针,可以设置其结构如下:
#define MAX_LEN 1024
typedef struct KDtree{
double data[MAX_LEN]; // 数据
int dim; // 选择的维度
struct KDtree *left; // 左子树
struct KDtree *right; // 右子树
}kdtree_node;
构造kd树的函数声明为:
int kdtree_insert(kdtree_node *&tree_node, double *data, int layer, int dim);
函数的具体实现如下:
// 递归构建kd树,通过节点所在的层数控制选择的维度
int kdtree_insert(kdtree_node * &tree_node, double *data, int layer, int dim){
// 空树
if (NULL == tree_node){
// 申请空间
tree_node = (kdtree_node *)malloc(sizeof(kdtree_node));
if (NULL == tree_node) return 1;
//插入元素
for (int i = 0; i < dim; i ++){
(tree_node->data)[i] = data[i];
}
tree_node->dim = layer % (dim);
tree_node->left = NULL;
tree_node->right = NULL;
return 0;
}
// 插入左子树
if (data[tree_node->dim] <= (tree_node->data)[tree_node->dim]){
return kdtree_insert(tree_node->left, data, ++layer, dim);
}
// 插入右子树
return kdtree_insert(tree_node->right, data, ++layer, dim);
}
当构建好了kd树后,需要对kd树进行遍历,在这里,实现了两种kd树的遍历方法:
- 先序遍历
- 中序遍历
对于先序遍历,其函数的声明为:
void kdtree_print(kdtree_node *tree, int dim);
函数的具体实现为:
void kdtree_print(kdtree_node *tree, int dim){
if (tree != NULL){
fprintf(stderr, "dim:%dn", tree->dim);
for (int i = 0; i < dim; i++){
fprintf(stderr, "%lft", (tree->data)[i]);
}
fprintf(stderr, "n");
kdtree_print(tree->left, dim);
kdtree_print(tree->right, dim);
}
}
对于中序遍历,其函数的声明为:
void kdtree_print_in(kdtree_node *tree, int dim);
函数的具体实现为:
void kdtree_print_in(kdtree_node *tree, int dim){
if (tree != NULL){
kdtree_print_in(tree->left, dim);
fprintf(stderr, "dim:%dn", tree->dim);
for (int i = 0; i < dim; i++){
fprintf(stderr, "%lft", (tree->data)[i]);
}
fprintf(stderr, "n");
kdtree_print_in(tree->right, dim);
}
}
4、kd树的检索
与二叉排序树一样,在kd树中,将样本划分到不同的空间中,在查找的过程中,由于查找在某些情况下仅需查找部分的空间,这为查找的过程节省了对大部分数据点的搜索的时间,对于kd树的检索,其具体过程为:
- 从根节点开始,将待检索的样本划分到对应的区域中(在kd树形结构中,从根节点开始查找,直到叶子节点,将这样的查找序列存储到栈中)
- 以栈顶元素与待检索的样本之间的距离作为最短距离min_distance
- 执行出栈操作:
- 向上回溯,查找到父节点,若父节点与待检索样本之间的距离小于当前的最短距离min_distance,则替换当前的最短距离min_distance
- 以待检索的样本为圆心(二维,高维情况下是球心),以min_distance为半径画圆,若圆与父节点所在的平面相割,则需要将父节点的另一棵子树进栈,重新执行以上的出栈操作
- 直到栈为空
以查找(6,3)left ( 6,3 right )为例,首先,我们需要找到待查找的样本所在的搜索空间,搜索空间如下图中的黑色区域所示:
其对应的进栈序列为:{(3,6),(7,5),(6,2)}left { left ( 3,6 right ),left ( 7,5 right ),left ( 6,2 right ) right }。
此时,以到(6,2)left ( 6,2 right )之间的距离为最短距离,最短距离min_distance为1,对栈顶元素出栈,此时栈中的序列为:{(3,6),(7,5)}left { left ( 3,6 right ),left ( 7,5 right ) right }。以待检索样本(6,3)left ( 6,3 right )为圆心,1为半径画圆,圆与(6,2)left ( 6,2 right )所在平面相割,如下图所示:
此时,需要检索以(6,2)left ( 6,2 right )为根节点的另外一棵子树,即需要将(9,1)left ( 9,1 right )进栈,此时,栈中的序列为:{(3,6),(7,5),(9,1)}left { left ( 3,6 right ),left ( 7,5 right ),left ( 9,1 right ) right }。
注意:若需要进栈的子树中有很多节点,则根据需要比较的元素的大小,将直到叶节点的所有节点都进栈,这一点在很多地方都写得不清楚。
按照上述的步骤,再执行出栈的操作,直到栈为空。
检索过程的函数声明为:
void search_nearest(kdtree_node *tree, double *data_search, int dim, double *result);
函数的具体实现为:
void search_nearest(kdtree_node *tree, double *data_search, int dim, double *result){
// 一直找到叶子节点
fprintf(stderr, "nstart searching....n");
stack<kdtree_node *> st;
kdtree_node *p = tree;
while (p->left != NULL || p->right != NULL){
st.push(p);// 将p压栈
if (data_search[p->dim] <= (p->data)[p->dim]){// 选择左子树
// 判断左子树是否为空
if (p->left == NULL) break;
p = p->left;
}else{ // 选择右子树
if (p->right == NULL) break;
p = p->right;
}
}
// 现在与栈中的数据进行对比
double min_distance = distance(data_search, p->data, dim);// 与根结点之间的距离
fprintf(stderr, "init: %lfn", min_distance);
copy2result(p->data, result, dim);
// 打印最优值
for (int i = 0; i < dim; i++){
fprintf(stderr, "%lft", result[i]);
}
fprintf(stderr, "n");
double d = 0;
while (st.size() > 0){
kdtree_node *q = st.top();// 找到栈顶元素
st.pop(); // 出栈
// 判断与父节点之间的距离
d = distance(data_search, q->data, dim);
if (d <= min_distance){
min_distance = d;
copy2result(q->data, result, dim);
}
// 判断与分隔面是否相交
double d_line = distance_except_dim(data_search, q->data, q->dim); // 到平面之间的距离
if (d_line < min_distance){ // 相交
// 如果本来在右子树,现在查找左子树
// 如果本来在左子树,现在查找右子树
if (data_search[q->dim] > (q->data)[q->dim]){
// 选择左子树
if (q->left != NULL) q = q->left;
else q = NULL;
}else{
// 选择右子树
if (q->right != NULL) q = q->right;
else q = NULL;
}
if (q != NULL){
while (q->left != NULL || q->right != NULL){
st.push(q);
if (data_search[q->dim] <= (q->data)[q->dim]){
if (q->left == NULL) break;
q = q->left;
}else{
if (q->right == NULL) break;
q = q->right;
}
}
if (q->left == NULL && q->right == NULL) st.push(q);
}
}
}
}
在函数的实现中,需要用到的函数为:
- 两个样本之间的距离
double distance(double *a, double *b, int dim){
double d = 0.0;
for (int i = 0; i < dim; i ++){
d += (a[i] - b[i]) * (a[i] - b[i]);
}
return d;
}
- 待检索的样本到平面之间的距离
double distance_except_dim(double *a, double *b, int except_dim){
double d = (a[except_dim] - b[except_dim]) * (a[except_dim] - b[except_dim]);
return d;
}
- 复制最优的结果
void copy2result(double *a, double *result, int dim){
for (int i = 0; i < dim; i ++){
result[i] = a[i];
}
}
三、测试
利用如上的测试集,我们构建kd树,并在kd树中查找(6,3)left ( 6,3 right ),测试代码如下:
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "kdtree.h"
// 解析特征
int parse_feature(char *p, double *fea, int *dim){
// 解析特征
char *q = p;
int i = 0;
while ((q = strchr(p, 't')) != NULL){
*q = 0;
fea[i] = atof(p);
//fprintf(stderr, "atof(p):%lfn", atof(p));
p = q + 1;
//r = r + 1;
i += 1;
}
// 解析最后一个
fea[i] = atof(p);
*dim = i + 1;
//fprintf(stderr, "atof(p):%lfn", atof(p));
//fprintf(stderr, "fea:%lft%lfn", fea[0], fea[1]);
}
int main(){
kdtree_node *tree_node = NULL;
// 从文件中读入数据
FILE *fp = fopen("data.txt", "r");
char feature[MAX_LEN];
double data[MAX_LEN];
int data_dim = 0; // 数据的维数
double data_search[2] = {6.0, 3.0};
while (fgets(feature, MAX_LEN, fp)){
fprintf(stderr, "%s", feature);
parse_feature(feature, data, &data_dim);
fprintf(stderr, "distance: %lfn", distance(data, data_search, data_dim));
// 插入到kd树中
kdtree_insert(tree_node, data, 0, data_dim);
}
fclose(fp);
fprintf(stderr, "dim:%dn", data_dim);
fprintf(stderr, "insert_okn");
// test
kdtree_print(tree_node, data_dim);
printf("n");
kdtree_print_in(tree_node, data_dim);
double result[2];
search_nearest(tree_node, data_search, data_dim, result);
fprintf(stderr, "n the final result: ");
for (int i = 0; i < data_dim; i++){
fprintf(stderr, "%lft", result[i]);
}
fprintf(stderr, "n");
return 0;
}
以上的代码以上处至Github,其地址为:kd-tree。若有不对的地方,欢迎指正。
参考文献
- K近邻算法基础:KD树的操作
- k近邻法的C++实现:kd树
- An intoductory tutorial on kd-trees
- Range Searching using Kd Tree
- 最近邻算法的实现:k-d tree
- 从K近邻算法、距离度量谈到KD树、SIFT+BBF算法
- 《统计机器学习》
- CVE-2017-3085:Adobe Flash泄漏Windows用户凭证
- hbase源码系列(九)StoreFile存储格式
- 如何确定恶意软件是否在自己的电脑中执行过?
- Carbondata源码系列(二)文件格式详解
- 挖洞经验 | 记一次针对Twitter(Periscope)API 的有趣挖洞经历
- 设计模式学习(二): 观察者模式 (C#)
- Carbondata源码系列(一)文件生成过程
- BoopSuite:基于Python编写的无线安全审计套件
- 设计模式学习(一):多用组合少用继承(C#)
- 在asp.net web api 2 (ioc autofac) 使用 Serilog 记录日志
- hbase源码系列(十三)缓存机制MemStore与Block Cache
- hbase源码系列(十四)Compact和Split
- 设计模式学习(四): 1.简单工厂 (附C#实现)
- 从头编写 asp.net core 2.0 web api 基础框架 (5) EF CRUD
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Tensorflow与Keras自适应使用显存方式
- Python类及获取对象属性方法解析
- Keras实现DenseNet结构操作
- python中format函数如何使用
- keras得到每层的系数方式
- 解决TensorFlow调用Keras库函数存在的问题
- php判断电子邮件是否正确方法
- python db类用法说明
- python中wheel的用法整理
- 使用Keras训练好的.h5模型来测试一个实例
- python中查看.db文件中表格的名字及表格中的字段操作
- Ubuntu 16.04中Laravel5.4升级到5.6的步骤
- PHP SESSION机制的理解与实例
- Yii支持多域名cors原理的实现
- PHP实现的pdo连接数据库并插入数据功能简单示例