TensorFlow 2.0 快速搭建神经网络
时间:2020-08-01
本文章向大家介绍TensorFlow 2.0 快速搭建神经网络,主要包括TensorFlow 2.0 快速搭建神经网络使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
tf.keras 是 TensorFlow2 引入的高度封装框架,可以快速搭建神经网络模型。下面介绍一些常用API,更多内容可以参考官方文档:tensorflow
1 tf.keras 搭建神经网络六步法
- import
- train, test
- model = tf.keras.models.Sequential
- model.compile
- model.fit
- model.summary
1.1 import 相关模块
如 import tensorflow as tf
1.2 指定输入网络的训练集和测试集
如指定训练集的输入 x_train 和标签 y_train,测试集的输入 x_test 和 y_test。
1.3 逐层搭建网络结构
model = tf.keras.models.Sequential()
model = tf.keras.models.Sequential()
Sequential函数是一个容器,描述了神经网络的网络结构,在 Sequential 函数的输入参数中描述从输入层到输出层的网络结构。
网络结构示例:
# 拉直层 tf.keras.layers.Flatten()
# 全连接层 tf.keras.layers.Dense( 神经元个数, activation=”激活函数”, kernel_regularizer=”正则化方式”) # 可选参数 # activation: relu、softmax、sigmoid、tanh 等 # kernel_regularizer: tf.keras.regularizers.l1()、tf.keras.regularizers.l2()
# 卷积层 tf.keras.layers.Conv2D( filter = 卷积核个数, kernel_size = 卷积核尺寸, strides = 卷积步长, padding = “valid” or “same”)
用于正则化的 范数。
1.4 配置神经网络的训练方法
告知训练时使用的优化器、损失函数和准确率评测标准。
model.compile( optimizer = 优化器, loss = 损失函数, metrics = [“准确率”])
optimizer 可以是字符串形式给出的优化器名字,也可以是函数形式,使用函数形式可以设置学习率、动量和超参数。详细可参考:神经网络中的优化器
‘sgd’or tf.optimizers.SGD( lr=学习率, decay=学习率衰减率, momentum=动量参数) ‘adagrad’or tf.keras.optimizers.Adagrad(lr=学习率, decay=学习率衰减率) ‘adadelta’or tf.keras.optimizers.Adadelta(lr=学习率, decay=学习率衰减率) ‘adam’or tf.keras.optimizers.Adam (lr=学习率, decay=学习率衰减率)
loss 可以是字符串形式给出的损失函数的名字,也可以是函数形式。详细可参考:神经网络中的损失函数
‘mse’or tf.keras.losses.MeanSquaredError() ‘sparse_categorical_crossentropy’ or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
损失函数常需要经过 softmax 等函数将输出转化为概率分布的形式。from_logits 则用来标注该损失函数是否需要转换为概率的形式,取 False 时表示转化为概率分布,取 True 时表示没有转化为概率分布,直接输出。
metrics 标注网络评测指标。注意,metrics 可以为复数形式,即指定多个 metrics。
‘accuracy’:y_和 y 都是数值。
如 y_=[1] y=[1]。 ‘categorical_accuracy’:y_和 y 都是以独热码和概率分布表示。 如 y_=[0, 1, 0], y=[0.256, 0.695, 0.048]。 ‘sparse_ categorical_accuracy’:y_是以数值形式给出,y 是以独热码形式给出。 如 y_=[1],y=[0.256, 0.695, 0.048]。
1.5 执行训练过程
model.fit(训练集的输入特征, 训练集的标签,
batch_size,
epochs,
validation_data = (测试集的输入特征,测试集的标签),
validataion_split = 从训练集划分多少比例给测试集,
validation_freq = 测试的epoch间隔次数)
更多参数可见:model.fit()。
关于 batch_size、epochs、iteration 区别。
关于 训练集、验证集、测试集 概念。
1.6 打印网络结构和参数统计
model.summary()
1.7 class声明网络结构
使用 Sequential 可以快速搭建网络结构,但是如果网络包含跳连等其他复杂网络结构,Sequential 就无法表示了。这就需要使用 class 来声明网络结构。
class MyModel(Model): def __init__(self): super(MyModel, self).__init__() 初始化网络结构 def call(self, x): y = self.d1(x) return y
model = MyModel()
使用 class 类封装网络结构,如上所示是一个 class 模板,MyModel 表示声明的神经网络的名字,括号中的 Model 表示创建的类需要继承 tensorflow 库中的 Model 类。
__init__():定义所需网络结构块
call():写出前向传播。
2 iris 数据集代码示例
import tensorflow as tf from sklearn import datasets import numpy as np x_train = datasets.load_iris().data y_train = datasets.load_iris().target np.random.seed(116) np.random.shuffle(x_train) np.random.seed(116) np.random.shuffle(y_train) model = tf.keras.Sequential([ tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2()) ]) model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=[tf.keras.metrics.sparse_categorical_accuracy]) model.fit(x_train, y_train, batch_size=32, epochs=500, validation_freq=20, validation_split=0.2) model.summary()
如果我们使用 class 申明网络结构,可写成如下形式:
import tensorflow as tf from sklearn import datasets import numpy as np x_train = datasets.load_iris().data y_train = datasets.load_iris().target np.random.seed(116) np.random.shuffle(x_train) np.random.seed(116) np.random.shuffle(y_train) class IrisModel(tf.keras.Model): def __init__(self): super(IrisModel, self).__init__() self.d1 = tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2()) def call(self, inputs, training=None, mask=None): y = self.d1(inputs) return y model = IrisModel() model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=[tf.keras.metrics.sparse_categorical_accuracy]) model.fit(x_train, y_train, batch_size=32, epochs=10, validation_freq=2, validation_split=0.2) model.summary()
3 MNIST 数据集代码示例
MNIST 数据集是一个有七万张图片、28×28 像素的 0~9 手写数字数据集。其中六万张用于训练,一万张用于测试。
import tensorflow as tf mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation=tf.keras.activations.relu), tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax) ]) model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=[tf.keras.metrics.sparse_categorical_accuracy]) model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1) model.summary()
使用 class 申明网络结构
import tensorflow as tf mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 print(x_train.shape) class MnistModel(tf.keras.Model): def __init__(self): super(MnistModel, self).__init__() self.flatten = tf.keras.layers.Flatten() self.d1 = tf.keras.layers.Dense(128, activation=tf.keras.activations.relu) self.d2 = tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax) def call(self, inputs, training=None, mask=None): x = self.flatten(inputs) x = self.d1(x) y = self.d2(x) return y model = MnistModel() model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=[tf.keras.metrics.sparse_categorical_accuracy]) model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1) model.summary()
原文地址:https://www.cnblogs.com/sun-a/p/13382713.html
- ios学习——键盘的收起
- IOS学习7——cocoapod安装与使用教程
- 使用Docker环境快速搭建靶机环境
- Java标准I/O流编程一览笔录
- 十分钟学perl够用(客服MM都懂了)
- Java多线程并发编程一览笔录
- Tomcat6/7应用服务器-禁用RC4等弱密码套件
- mybaits3整合spring总结
- 如何使用Airgeddon找回WiFi密码
- 设计缺陷将导致亚马逊Echo变身成为监听设备
- Unity引擎与C#脚本简介
- Redis分布式缓存系统Lua脚本食用指引
- 基于复杂方案OWSAP CsrfGuard的CSRF安全解决方案(适配nginx + DWR)
- XMLHttpRequest对象如何兼容各浏览器使用?
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法