利用Tensorflow2.0实现手写数字识别
前面两节课我们已经简单了解了神经网络的前向传播和反向传播工作原理,并且尝试用numpy实现了第一个神经网络模型。手动实现(深度)神经网络模型听起来很牛逼,实际上却是一个费时费力的过程,特别是在神经网络层数很多的情况下,多达几十甚至上百层网络的时候我们就很难手动去实现了。这时候可能我们就需要更强大的深度学习框架来帮助我们快速实现深度神经网络模型,例如Tensorflow/Pytorch/Caffe等都是非常好的选择,而近期大热的keras是Tensorflow2.0版本中非常重要的高阶API,所以本节课老shi打算先给大家简单介绍下Tensorflow的基础知识,最后借助keras来实现一个非常经典的深度学习入门案例——手写数字识别。废话不多说,马上进入正题。
什么是Tensorflow
Tensorflow是谷歌2015年推出的一款深度学习框架,与Pytorch类似,都是目前比较热门的深度学习框架。但Tensorflow与传统的模型搭建方式不同,它是采用数据流图的方式来计算, 所以我们首先得创建一个数据流图,然后再将我们的数据(数据以张量tensor的形式存在)放到数据流图中去计算,节点Nodes在图中表示数学操作,图中的边edges则表示在节点间相互联系的多维数组, 即张量(tensor)。训练模型时tensor会不断地从数据流图中的一个节点flow到另一个节点, 这也是Tensorflow名字的由来。计算图Graph规定了各个变量之间的计算关系,建立好的计算图需要编译以确定其内部细节,而此时的计算图还是一个“空壳子”,里面并没有任何实际的数据,只有当你把需要运算的输入数据放进去后,才能在整个模型中形成数据流,从而得到模型的输出值。打个比方,就像用管道搭建的供水系统,当你在拼接水管的时候,水管里面其实是没有水的,只有等所有的管子都接好了,才能进行供水。具体如下图所示
Tensorflow中的基本概念
计算图(Graph):计算图描述了计算的过程,Tensorflow使用计算图来表示计算任务。
张量(Tensor):Tensorflow使用tensor表示数据。每个tensor是一个类型化的多维数组。规模最小的张量是0阶张量,即标量,也就是一个数;当我们把一些数有序地排列起来,就形成了1阶张量,也就是向量;如果我们继续把一组向量有序排列起来,就得到了一个2阶张量,也就是一个矩阵 ;把矩阵堆起来就是3阶张量,也就得到了一个立方体,我们常见的3通道(3色RGB)的彩色图片也是一个立方体;如果我们继续把立方体堆起来,就得到一个4阶的张量,以此类推。
操作(op):计算图中的节点被称为op(operation的缩写),即操作 op=节点Nodes;一个op获得0个或多个Tensor,执行计算后,就会产生0个或多个Tensor。
会话(Session):计算图必须在“会话”的上下文中执行。会话将计算图的op分发到如CPU或GPU之类的设备上执行。
变量(Variable):运行过程中可以被改变的量,用于维护状态。
Tensorflow2.0相比Tensorflow1.x版本的改进
1、支持tf.data加载数据,使用tf.data创建的输入管道读取训练数据,支持从内存(Numpy)方便地输入数据;
2、取消了会话Session,由静态计算图变成动态计算图,直接打印结果,不需要执行会话的过程;
3、使用tf.keras构建、训练和验证模型,或使用Premade来验证模型,可以直接标准的打包模型(逻辑回归、随机森林),也可以直接使用tf.estimator API 。如果不想从头训练模型,可以使用迁移学习来训练一个使用TensorflowHub模块的Keras或Estimator;
4、使用分发策略进行分发训练,分发策略API可以在不更改定义的情况下,轻松在不同的硬件配置上分发和训练模型,支持一系列的硬件加速器,例如GPU、TPU等;
5、使用SaveModel作为模型保存模块,更好对接线上部署。
最后,我们使用Tensorflow2.0高阶API keras来实现深度学习经典入门案例——手写数字识别,以下是案例代码,有兴趣的同学可以跟着实现一遍。下节课给大家带来卷积神经网络CNN,敬请期待!!
#coding:utf8import numpy as npnp.random.seed(123)#后面只使用keras.model搭建一个简单的全连接网络模型,不用tf.keras中的特性,在此直接用import keras也可以from tensorflow import kerasfrom keras.datasets import mnistfrom keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers import Dense,Activationfrom keras.optimizers import RMSprop
# 数据导入(x_train,y_train),(x_test,y_test) = mnist.load_data()print(x_train.shape,y_train.shape)print(x_test.shape,y_test.shape)
# 数据预处理x_train = x_train.reshape(x_train.shape[0],-1) / 255.0x_test = x_test.reshape(x_test.shape[0],-1) / 255.0y_train = np_utils.to_categorical(y_train,num_classes=10)y_test = np_utils.to_categorical(y_test,num_classes=10)
# 直接使用keras.Sequential()搭建全连接网络模型model = Sequential()model.add(Dense(128, input_shape=(784,)))model.add(Activation('relu'))model.add(Dense(10))model.add(Activation('softmax'))
#lr为学习率,epsilon防止出现0,rho/decay分别对应公式中的beta_1和beta_2rmsprop = RMSprop(lr=0.001,rho=0.9,epsilon=1e-08,decay=0.00001) model.compile(optimizer=rmsprop,loss='categorical_crossentropy',metrics=['accuracy'])print("---------------training--------------")model.fit(x_train,y_train,epochs=5,batch_size=32)print('n')print("--------------testing----------------")loss,accuracy = model.evaluate(x_test,y_test)print('loss:',loss)print('accuracy:',accuracy)
- 如何使用Python读取大文件
- 介绍一种非常好用汇总数据的方式GROUPING SETS
- 史上最大的CPU Bug(幽灵和熔断的OS&SQLServer补丁)
- 数据库副本的自动种子设定(自增长)
- Git 项目推荐 | 基于go+protobuff 实现的分布式
- ReflectASM-invoke,高效率java反射机制原理
- Web应用渗透测试-本地文件包含
- shiro权限控制(二):分布式架构中shiro的实现
- Groovy实现原理分析——准备工作
- HBCTF第一场2个pwn题的简单分析
- ACM竞赛之输入输出(以C与C++为例)
- 能让程序做的事情坚决不用人来做——批量修复markdownlint MD034警告
- swift demo1 tableview
- Swift Alamofire
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 【STM32H7】第9章 RL-TCPnet调试方法(Event Recorder和串口两种)
- 【STM32F429】第9章 RL-TCPnet调试方法(Event Recorder和串口两种)
- 【STM32F407】第9章 RL-TCPnet V7.X调试方法(Event Recorder和串口两种)
- CentOS7的udev的绑定规则
- 案例:记录一则强制开库遭遇ORA-16433的处理过程
- mybatis升级为mybatis-plus踩到的坑
- Treepath
- linux 远程ssh免密登录
- npm 安装 electron taobao镜像 404错误 自用 实践笔记
- Asp.net Core 使用Jenkins + Dockor 实现持续集成、自动化部署(二):部署
- 队列的一种实现:循环队列
- StackExchange.Redis .net core Timeout performing 超时问题
- G1 垃圾回收器简单调优
- Docker安装官方Redis镜像并启用密码认证 实践笔记
- Asp.net Core 使用Jenkins + Dockor 实现持续集成、自动化部署(四):发布与回滚