扩展之Tensorflow2.0 | 20 TF2的eager模式与求导
【机器学习炼丹术】的学习笔记分享
参考目录:
- 1 什么是eager模式
- 2 TF1.0 vs TF2.0
- 3 获取导数/梯度
- 4 获取高阶导数
之前讲解了如何构建数据集,如何创建TFREC文件,如何构建模型,如何存储模型。这一篇文章主要讲解,TF2中提出的一个eager模式,这个模式大大简化了TF的复杂程度。
1 什么是eager模式
Eager模式(积极模式),我认为是TensorFlow2.0最大的更新,没有之一。
Tensorflow1.0的时候还是静态计算图,在《小白学PyTorch》系列的第一篇内容,就讲解了Tensorflow的静态特征图和PyTorch的动态特征图的区别。Tensorflow2.0提出了eager模式,在这个模式下,也支持了动态特征图的构建
不得不说,改的和PyTorch越来越像了,但是人类的工具总是向着简单易用的方向发展,这肯定是无可厚非的。
2 TF1.0 vs TF2.0
TF1.0中加入要计算梯度,是只能构建静态计算图的。
- 是先构建计算流程;
- 然后开始起一个会话对象;
- 把数据放到这个静态的数据图中。
整个流程非常的繁琐。
# 这个是tensorflow1.0的代码
import tensorflow as tf
a = tf.constant(3.0)
b = tf.placeholder(dtype = tf.float32)
c = tf.add(a,b)
sess = tf.Session() #创建会话对象
init = tf.global_variables_ini tializer()
sess.run(init) #初始化会话对象
feed = {
b: 2.0
} #对变量b赋值
c_res = sess.run(c, feed) #通过会话驱动计算图获取计算结果
print(c_res)
代码中,我们需要用palceholder先开辟一个内存空间,然后构建好静态计算图后,在把数据赋值到这个被开辟的内存中,然后再运行整个计算流程。
下面我们来看在eager模式下运行上面的代码
import tensorflow as tf
a = tf.Variable(2)
b = tf.Variable(20)
c = a + b
没错,这样的话,就已经完成一个动态计算图的构建,TF2是默认开启eager模式的,所以不需要要额外的设置了。这样的构建方法,和PyTorch是非常类似的。
3 获取导数/梯度
假如我们使用的是PyTorch,那么我们如何得到
的导数呢?
import torch
# Create tensors.
x = torch.tensor(10., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)
# Build a computational graph.
y = w * x + b # y = 2 * x + 3
# Compute gradients.
y.backward()
# Print out the gradients.
print(x.grad) # tensor(2.)
print(w.grad) # tensor(10.)
print(b.grad) # tensor(1.)
都没问题吧,下面用Tensorflow2.0来重写一下上面的内容:
import tensorflow as tf
x = tf.convert_to_tensor(10.)
w = tf.Variable(2.)
b = tf.Variable(3.)
with tf.GradientTape() as tape:
z = w * x + b
dz_dw = tape.gradient(z,w)
print(dz_dw)
>>> tf.Tensor(10.0, shape=(), dtype=float32)
我们需要注意这几点:
- 首先结果来看,没问题,w的梯度就是10;
- 对于参与计算梯度、也就是参与梯度下降的变量,是需要用
tf.Varaible
来定义的; - 不管是变量还是输入数据,都要求是浮点数float,如果是整数的话会报错,并且梯度计算输出None;
- tensorflow提供tf.GradientTape来实现自动求导,所以在tf.GradientTape内进行的操作,都会记录在tape当中,这个就是tape的概念。一个摄影带,把计算的过程录下来,然后进行求导操作
现在我们不仅要输出w的梯度,还要输出b的梯度,我们把上面的代码改成:
import tensorflow as tf
x = tf.convert_to_tensor(10.)
w = tf.Variable(2.)
b = tf.Variable(3.)
with tf.GradientTape() as tape:
z = w * x + b
dz_dw = tape.gradient(z,w)
dz_db = tape.gradient(z,b)
print(dz_dw)
print(dz_db)
运行结果为:
这个错误翻译过来就是一个non-persistent的录像带,只能被要求计算一次梯度。 我们用tape计算了w的梯度,然后这个tape清空了数据,所有我们不能再计算b的梯度。
解决方法也很简单,我们只要设置这个tape是persistent就行了:
import tensorflow as tf
x = tf.convert_to_tensor(10.)
w = tf.Variable(2.)
b = tf.Variable(3.)
with tf.GradientTape(persistent=True) as tape:
z = w * x + b
dz_dw = tape.gradient(z,w)
dz_db = tape.gradient(z,b)
print(dz_dw)
print(dz_db)
运行结果为:
4 获取高阶导数
import tensorflow as tf
x = tf.Variable(1.0)
with tf.GradientTape() as t1:
with tf.GradientTape() as t2:
y = x * x * x
dy_dx = t2.gradient(y, x)
print(dy_dx)
d2y_d2x = t1.gradient(dy_dx, x)
print(d2y_d2x)
>>> tf.Tensor(3.0, shape=(), dtype=float32)
>>> tf.Tensor(6.0, shape=(), dtype=float32)
想要得到二阶导数,就要使用两个tape,然后对一阶导数再求导就行了。
- END -
- [接口测试 - 基础篇] 06 好吧也来解析下html
- [接口测试 - 基础篇] 05 好讨厌的xml解析
- 【专知-关关的刷题日记17】Leetcode 268. Missing Number
- 【专知-关关的刷题日记18】Leetcode 35. Search Insert Position
- [接口测试 - http.client篇] 15 常用API说明及基本的示例
- [接口测试 - http.client篇] 14 源码初探及其工作机制分析
- 【专知-关关的刷题日记19】Leetcode 118. Pascal's Triangle
- 每周学点大数据 | No.3算法设计与分析理论
- HDU 1874 畅通工程续【Floyd算法实现】
- 接口测试 | 21 基于flask弄个restful API服务出来
- 数论部分第二节:埃拉托斯特尼筛法 埃拉托斯特尼筛法
- [接口测试 -基础篇] 20 用flask写一个简单server用于接口测试
- 接口测试 | urllib篇 19 urllib基本示例
- 接口测试 | urllib篇 18 urllib介绍
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 【剑指Offer】栈的压入、弹出序列
- 【剑指Offer】包含min函数的栈
- 【剑指Offer】顺时针打印矩阵
- 从0打造属于自己的windows开发命令终端
- hadoop数据类型及自定义
- 惊!u202a错误,百分之九十都不知道的隐藏在文件路径里的惊天秘密!(干货收藏)
- 百度站点收录 - 什么叫自动推送
- 虚拟机安装Centos后的一些配置
- CentOS下的JDK安装
- python 技术篇-3行代码搞定图像文字识别,pytesseract库实现
- hadoop2.6.0完全分布式手动安装
- Python 库安装问题:ModuleNotFoundError: No module named 'windows'. 解决方法
- Python各种文件删除函数的功能区分!
- Python 技术篇-轻松操作windows系统电脑鼠标指针移动、点击
- Typora Picgo自动使用图床上传图片