使用Keras创建一个卷积神经网络模型,可对手写数字进行识别
在过去的几年里,图像识别研究已经达到了惊人的精确度。不可否认的是,深度学习在这个领域击败了传统的计算机视觉技术。
将神经网络应用于MNIST的数据集以识别手写的数字这种方法将所有的图像像素传输到完全连接的神经网络。该方法在测试集上的准确率为98.01%。这个成功率虽然看上去不错,但不是完美的。
应用卷积神经网络可以产生更成功的结果。与传统的方法相比,重点部分的图像像素将被传输到完全连接的神经网络,而不是所有的图像像素。一些滤镜应该被应用到图片中去检测重点部分的像素。
Keras是一个使用通用深度学习框架的API,并且可以更容易地构建深度学习模型。它还减少了代码的复杂性。我们可以编写更短的代码来在Keras中实现同样的目的。同样,相同的Keras代码可以在不同的平台上运行,比如TensorFlow或Theano。你所需要的只是更改配置,以切换深度学习框架。在本文中,我们将使用Keras来创建一个卷积神经网络模型。
首先,我们将导入所需的Keras库:
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
其次,我们将加载mnist数据集。这个数据集已经作为训练集和测试集被分离了。训练集和数据集包括功能部件和标签。
(x_train, y_train), (x_test, y_test) = mnist.load_data()
第三,Keras要求我们在3D矩阵上进行输入特征的工作。因此,我们将把训练集和测试集的特征转换为3D矩阵。输入特征是大小为28×28的二维矩阵。这些矩阵保持不变,我们只添加一个虚拟维度,矩阵就会被转换成28x28x1。此外,输入特征必须在0到1之间。这就是为什么,特征会被划分为255,从而标准化[0, 1]。
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255 #inputs have to be between [0, 1]
x_test /= 255
数据集标签在0到9的范围内。Keras让我们在二进制类标签上工作。下面的块将把标签转换成二进制格式。(例如,标签2将被表示为0010000000)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
这不是必须的,但我们会在结构上保持“忠诚”。卷积和合并操作将被应用两次。在那之后,学习的功能将被转移到一个由一个隐藏层组成的完全连接的神经网络。你可以更改网络的结构,并监视对准确性的影响。
卷进神经网络流程
现在,我们将构建卷积神经网络的结构。
model = Sequential()
#1st convolution layer
model.add(Conv2D(32, (3, 3) #apply 32 filters size of (3, 3)
, input_shape=(28,28,1)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
#2nd convolution layer
model.add(Conv2D(64,(3, 3))) #apply 64 filters size of (3x3)
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Flatten())
# Fully connected layer. 1 hidden layer consisting of 512 nodes
model.add(Dense(512))
model.add(Activation('relu'))
#10 outputs
model.add(Dense(10, activation='softmax'))
你可能会注意到,完全连接的神经网络的输出层连接到卷积神经网络的输出层,而非线性函数。这个函数应该是softmax函数。这样,输出值在[0, 1]之间标准化。而且,输出的和总是等于1。最后,最大索引将触发结果。
标准数据集由60000个实例组成。在个人计算机上很难处理好所有的实例。这就是为什么,我更喜欢用随机选择的方法来训练网络。如果你有时间或很好的的硬件,你也许会跳过这一步,并且希望在所有实例上工作。
gen = ImageDataGenerator()
train_generator = gen.flow(x_train, y_train, batch_size=batch_size)
现在,是时候训练网络了。
model.compile(loss='categorical_crossentropy'
, optimizer=keras.optimizers.Adam()
, metrics=['accuracy']
)
model.fit_generator(train_generator
, steps_per_epoch=batch_size
, epochs=epochs,
validation_data=(x_test, y_test))
一旦网络被训练,我们就可以怀疑成功标准。
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', 100*score[1])
经典的完全连通的神经网络获得了98.01%的准确率,而卷积神经网络的准确率则超过了99%。这是一个令人难以置信的结果。所有这些测试都是在一个Core i7 CPU上执行的。此外,增加批量大小和epoch会提高准确率。
最后得分
最后,我创建了具有以下配置的模型:
batch_size = 250
epochs = 10
因此,图像识别研究将会被卷积神经网络进一步发展。
本文所用的代码:https://github.com/serengil/tensorflow-101/blob/master/python/HandwrittenDigitRecognitionUsingCNNWithKeras.py
- Spring+SpringMVC+MyBatis+easyUI整合优化篇(十二)数据层优化-explain关键字及慢sql优化
- 高吞吐koa日志中间件
- 关于SQLRecoverableException问题的排查和分析(r4笔记第13天)
- Spring+SpringMVC+MyBatis+easyUI整合优化篇(十三)数据层优化-表规范、索引优化
- node中的Stream-Readable和Writeable解读
- Spring+SpringMVC+MyBatis+easyUI整合进阶篇(六)一定要RESTful吗?
- 深入node之Transform
- 巧用shell脚本统计磁盘使用情况(r4笔记第12天)
- 使用fasttext实现文本处理及文本预测
- 关于导入导出sequence(r4笔记第11天)
- Spring+SpringMVC+MyBatis整合进阶篇(四)RESTful实战(前端代码修改)
- Nodejs cluster模块深入探究
- org.thymeleaf.exceptions.TemplateProcessingException: Exception evaluating SpringEL expression
- 巧用分析函数循序渐进解决实际问题 (r4笔记第10天)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 小程序顶部导航栏,可滑动,可动态选中放大
- 小程序不同页面的异步回调,callback和promise的使用讲解
- java入门019~springboot批量导入excel数据到mysql
- Java点餐系统和点餐小程序新加微信消息推送功能
- Java点餐系统和点餐小程序新加排号等位功能
- IDEA上给文件添加姓名,日期,版本号
- matlab机器人工具箱安装与卸载
- 浅谈Linux下修改/设置环境变量JAVA_HOME的方法
- Linux服务器配置多个svn仓库流程详解
- linux服务器显卡崩溃解决方案
- LINUX查看进程的4种方法(小结)
- Linux下的多线程编程实例解析
- CentOS使用expect批量远程执行脚本和命令
- Centos8最小化部署安装OpenStack Ussuri的详细教程
- 详解Xshell 常见问题及相关配置