简单易学的机器学习算法——K-近邻算法
一、近邻算法(Nearest Neighbors)
1、近邻算法的概念
近邻算法(Nearest Neighbors)是一种典型的非参模型,与生成方法(generalizing method)不同的是,在近邻算法中,通过以实例的形式存储所有的训练样本,假设有m个训练样本:
此时需要存储这m个训练样本,因此,近邻算法也称为基于实例的模型。
对于一个需要预测的样本,通过与存储好的训练样本比较,以较为相似的样本的标签作为近邻算法的预测结果。
2、近邻算法的分类
在近邻算法中,根据处理的问题的不同,可以分为:
- 近邻分类算法
- 近邻回归算法
在本篇博文中,主要介绍近邻分类算法。 注意:除了上述的监督式近邻算法,在近邻算法中,还有一类非监督的近邻算法。
二、近邻分类算法
1、近邻分类算法的概念
在近邻分类算法中,对于预测的数据,将其与训练样本进行比较,找到最为相似的K个训练样本,并以这K个训练样本中出现最多的标签作为最终的预测标签。
在近邻分类算法中,最主要的是K-近邻算法。
2、KNN算法概述
K-NN算法是最简单的分类算法,主要的思想是计算待分类样本与训练样本之间的差异性,并将差异按照由小到大排序,选出前面K个差异最小的类别,并统计在K个中类别出现次数最多的类别为最相似的类,最终将待分类样本分到最相似的训练样本的类中。与投票(Vote)的机制类似。
3、样本差异性
比较常用的差异性计算方法为欧式距离。欧式距离:样本
与样本
之间的欧式距离为:
4、KNN算法的流程
- 求预测样本与训练样本之间的相似性
- 依据相似性排序
- 选择前K个最为相似的样本对应的类别
- 得到预测的分类结果
三、K-近邻算法实现
1、Python实现
以手写字体MNIST的识别为例,对于测试集中的每一个样本预测其类别,对于手写字体,如下图所示:
k_nn.py
# coding:UTF-8
import cPickle as pickle
import gzip
import numpy as np
def load_data(data_file):
with gzip.open(data_file, 'rb') as f:
train_set, valid_set, test_set = pickle.load(f)
return train_set[0], train_set[1], test_set[0], test_set[1]
def cal_distance(x, y):
return ((x - y) * (x - y).T)[0, 0]
def get_prediction(train_y, result):
result_dict = {}
for i in xrange(len(result)):
if train_y[result[i]] not in result_dict:
result_dict[train_y[result[i]]] = 1
else:
result_dict[train_y[result[i]]] += 1
predict = sorted(result_dict.items(), key=lambda d: d[1])
return predict[0][0]
def k_nn(train_data, train_y, test_data, k):
# print test_data
m = np.shape(test_data)[0] # 需要计算的样本的个数
m_train = np.shape(train_data)[0]
predict = []
for i in xrange(m):
# 对每一个需要计算的样本计算其与所有的训练数据之间的距离
distance_dict = {}
for i_train in xrange(m_train):
distance_dict[i_train] = cal_distance(train_data[i_train, :], test_data[i, :])
# 对距离进行排序,得到最终的前k个作为最终的预测
distance_result = sorted(distance_dict.items(), key=lambda d: d[1])
# 取出前k个的结果作为最终的结果
result = []
count = 0
for x in distance_result:
if count >= k:
break
result.append(x[0])
count += 1
# 得到预测
predict.append(get_prediction(train_y, result))
return predict
def get_correct_rate(result, test_y):
m = len(result)
correct = 0.0
for i in xrange(m):
if result[i] == test_y[i]:
correct += 1
return correct / m
if __name__ == "__main__":
# 1、导入
print "---------- 1、load data ------------"
train_x, train_y, test_x, test_y = load_data("mnist.pkl.gz")
# 2、利用k_NN计算
train_x = np.mat(train_x)
test_x = np.mat(test_x)
print "---------- 2、K-NN -------------"
result = k_nn(train_x, train_y, test_x[:10,:], 10)
print result
# 3、预测正确性
print "---------- 3、correct rate -------------"
print get_correct_rate(result, test_y)
当取K=10时,对测试集中的10个数据样本的最终的预测准确性为:70%,预测值为:[7, 2, 1, 0, 9, 1, 9, 9, 8, 9],原始值为[7 2 1 0 4 1 4 9 5 9]。
2、Scikit-leanrn库
在Scikit-learn库中对K-NN算法有很好的支持,核心程序为:
clf = neighbors.KNeighborsClassifier(n_neighbors)
clf.fit(X, y)
四、K-NN算法中存在的问题及解决方法
1、计算复杂度的问题
在K-NN算法中,每一个预测样本需要与所有的训练样本计算相似度,计算量比较大。比较常用的方法有K-D树,局部敏感哈希等等
2、K-NN的均匀投票
在上述的K-NN算法中,最终对标签的选择是通过投票的方式决定的,在投票的过程中,每一个训练样本的投票的权重是相等的,可以对每个训练样本的投票加权,以期望最相似的样本有更高的决策权。
参考文献
1、1.6. Nearest Neighbors点击打开链接
- Silverlight性能优化
- WCF后续之旅(6): 通过WCF Extension实现Context信息的传递
- WCF后续之旅(6): 通过WCF Extension实现Context信息的传递
- 理性的相亲方法!精品课:《决策树》
- Asp.Net无刷新分页( jquery.pagination.js)
- 为什么网站需要用CDN来加速?
- Jmeter常用获取数据的几种方式
- [Silverlight 4 RC]RichTextBox概览
- WCF后续之旅(4):WCF Extension Point 概览
- Asp.Net无刷新上传并裁剪头像
- 用泛型的IEqualityComparer<T>接口去重复项
- python与office(一)
- Asp.net 后台添加CSS、JS、Meta标签(帮助类)
- 分享一下cookies操作(增、删、改、查)小经验
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Maven快速入门
- TomCat安装及快速部署
- SpringCloud+MyBatis分页处理(前后端分离)
- 手把手教你搭建SpringCloud项目
- SpringCloud的@Value注解及GitLab配置使用
- 使用 cdk8s 与 Argo CD 进行 GitOps 实践
- 设计模式 | 模版方法
- Python 函数3000字使用总结
- 3D摇杆控制器一种简单实现!Cocos Creator 3D!
- 数据结构 | TencentOS-tiny中队列、环形队列、优先级队列的实现及使用
- RTOS内功修炼记(六)—— 任务间通信为什么不用全局变量?
- 程序员必备基础:加签验签
- 【Rust日报】2020-07-16 j4rs,一个在 Rust 中调用 Java 代码的 Crate
- Vue.js 3 正式进入 RC 阶段
- FeignClient注解及参数问题---SpringCloud微服务