KNN实现手写数字识别
时间:2022-04-23
本文章向大家介绍KNN实现手写数字识别,主要内容包括1 - 导入模块、2 - 导入数据及数据预处理、3 - 构建模型、基本概念、基础应用、原理机制和需要注意的事项等,并结合实例形式分析了其使用技巧,希望通过本文能帮助到大家理解应用这部分内容。
KNN实现手写数字识别
博客上显示这个没有Jupyter的好看,想看Jupyter Notebook的请戳KNN实现手写数字识别.ipynb
1 - 导入模块
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from ld_mnist import load_digits
%matplotlib inline
2 - 导入数据及数据预处理
import tensorflow as tf
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
def load_digits():
mnist = input_data.read_data_sets("path/", one_hot=True)
return mnist
mnist = load_digits()
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_datatrain-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_datatrain-labels-idx1-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_datat10k-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_datat10k-labels-idx1-ubyte.gz
数据维度
print("Train: "+ str(mnist.train.images.shape))
print("Train: "+ str(mnist.train.labels.shape))
print("Test: "+ str(mnist.test.images.shape))
print("Test: "+ str(mnist.test.labels.shape))
Train: (55000, 784)
Train: (55000, 10)
Test: (10000, 784)
Test: (10000, 10)
mnist数据采用的是TensorFlow的一个函数进行读取的,由上面的结果可以知道训练集数据X_train有55000个,每个X的数据长度是784(28*28)。
x_train, y_train, x_test, y_test = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
展示手写数字
nums = 6
for i in range(1,nums+1):
plt.subplot(1,nums,i)
plt.imshow(x_train[i].reshape(28,28), cmap="gray")
3 - 构建模型
class Knn():
def __init__(self,k):
self.k = k
self.distance = {}
def topKDistance(self, x_train, x_test):
'''
计算距离,这里采用欧氏距离
'''
print("计算距离...")
distance = {}
for i in range(x_test.shape[0]):
dis1 = x_train - x_test[i]
dis2 = np.sqrt(np.sum(dis1*dis1, axis=1))
distance[str(i)] = np.argsort(dis2)[:self.k]
if i%1000==0:
print(distance[str(i)])
return distance
def predict(self, x_train, y_train, x_test):
'''
预测
'''
self.distance = self.topKDistance(x_train, x_test)
y_hat = []
print("选出每项最佳预测结果")
for i in range(x_test.shape[0]):
classes = {}
for j in range(self.k):
num = np.argmax(y_train[self.distance[str(i)][j]])
classes[num] = classes.get(num, 0) + 1
sortClasses = sorted(classes.items(), key= lambda x:x[1], reverse=True)
y_hat.append(sortClasses[0][0])
y_hat = np.array(y_hat).reshape(-1,1)
return y_hat
def fit(self, x_train, y_train, x_test, y_test):
'''
计算准确率
'''
print("预测...")
y_hat = self.predict(x_train, y_train, x_test)
# index_hat =np.argmax(y_hat , axis=1)
print("计算准确率...")
index_test = np.argmax(y_test, axis=1).reshape(1,-1)
accuracy = np.sum(y_hat.reshape(index_test.shape) == index_test)*1.0/y_test.shape[0]
return accuracy, y_hat
clf = Knn(10)
accuracy, y_hat = clf.fit(x_train,y_train,x_test,y_test)
print(accuracy)
预测...
计算距离...
[48843 33620 11186 22059 42003 9563 39566 10260 35368 31395]
[54214 4002 11005 15264 49069 8791 38147 47304 51494 11053]
[46624 10708 22134 20108 48606 19774 7855 43740 51345 9308]
[ 8758 47844 50994 45610 1930 3312 30140 17618 910 51918]
[14953 1156 50024 26833 26006 38112 31080 9066 32112 41846]
[45824 14234 48282 28432 50966 22786 40902 52264 38552 44080]
[24878 4655 20258 36065 30755 15075 35584 12152 4683 43255]
[48891 20744 47822 53511 54545 27392 10240 3970 25721 30357]
[ 673 17747 33803 20960 25463 35723 969 50577 36714 35719]
[ 8255 42067 53282 14383 14073 52083 7233 8199 8963 12617]
选出每项最佳预测结果
计算准确率...
0.9672
准确率略高。
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 明亮解我“工厂模式无用”之惑
- 「源码分析」— 为什么枚举是单例模式的最佳方法
- 如何记忆 Spring Bean 的生命周期
- 系统学习Stream
- Java回调的四种写法(反射、直接调用、接口调用、Lamda表达式)
- 避开NullPointerException的10条建议
- REST服务,使用Dubbo还是SpringMVC?
- Linux系统下Anaconda的安装和使用教程
- Flutter Dojo设计之道——利用Github打造完善的开源项目
- 最强 Redis 客户端 lettuce 已支持 Redis6客户端
- 还在手动整理数据库文档?试试这个工具
- Elasticsearch 常见的 8 种错误及最佳实践
- Spark流式状态管理
- Scala中的IO操作及ArrayBuffer线程安全问题
- 设计模式之单例模式