PyTorch 60分钟入门系列之自动求导
Autograd:自动求导
在PyTorch中所有神经网络的核心是autograd
软件包。我们先来简单介绍一下这个,然后再构建第一个神经网络。
autograd
包为Tensors上的所有操作提供了自动求导。它是一个运行过程中定义的框架(define-by-run),这意味着反向传播是由代码的运行方式来定义的,并且每一次迭代都可能不同。
张量(Tensor)->0.4版本前是Variable
torch.Tensor
是包的中心类。如果你将属性.requires_grad
设置为True
,它将开始追踪所有的操作。当你完成了计算过程,你可以调用.backward()
,之后所有的梯度计算都是自动的。Tensor
的梯度将累积到.grad
属性中。
要停止跟踪历史记录的Tensor
,可以调用.detach()
将其从计算历史记录中分离出来,并防止跟踪将来的计算。
为了防止跟踪历史记录(和使用内存),你也可以用torch.no_grad()
包装代码块。 这在评估模型时特别有用,因为该模型可能具有require_grad = True
的可训练参数,但我们不需要梯度值。
还有一个类对于autograd
实现非常重要:一个Function
。
Tensor
和Function
是相互关联的,并建立一个非循环图,它编码构建了完整的计算过程。 每个变量都有一个.grad_fn
属性,该属性反应在已创建Tensor
的函数上(用户创建的Tensor
除外 - 它们的grad_fn
为None
)。
如果你想计算导数,可以在Tensor
上调用.backward()
。如果Tensor
是个标量(一个单元素数据),那么你不用为backward()
指定任何参数,然而如果它有多个元素,你需要指定一个gradient
参数,它是一个匹配尺寸的Tensor
。
import torch
x = torch.ones(2, 2, requires_grad=True) # 创建一个张量并设置`requires_grad = True`来跟踪计算
print(x) # 打印x的值
y = x + 2 # 对x张量进行计算操作
print(y) # 打印y值
print(y.grad_fn) # y是一个操作的结果,所以它有一个grad_fn。
print(y.requires_grad) # 打印y的requires_grad标志状态
z = y * y * 3 # 继续实现复杂的操作
out = z.mean() # 输出z的均值
print(z, out) # 打印计算输出结果
print(z.grad_fn)# y是一个操作的结果,所以它有一个grad_fn。
print(z.requires_grad) # 打印z的requires_grad标志状态
tensor([[ 1., 1.],
[ 1., 1.]])
tensor([[ 3., 3.],
[ 3., 3.]])
<AddBackward0 object at 0x7f181420a978>
True
tensor([[ 27., 27.],
[ 27., 27.]]) tensor(27.)
<MulBackward0 object at 0x7f180409e400>
True
.requires_grad_(...)
就地更改现有张量的requires_grad
标志。如果没有给出,函数输入标志默认为True
。需要注意的是:python 的默认参数,调用的时候,test( ) 与 test(True)等价跟内部flag默认值无关。从打印看,内部flag默认值是False,但是输出结果flag为True
a = torch.randn(2, 2) # 创建一个2*2的张量a
a = ((a * 3) / (a - 1))# 计算
print(a.requires_grad) # 打印a的requires_grad标志状态
a.requires_grad_(True) # 就地设置a的requires_grad标志状态
print(a.requires_grad) # 再次打印a的requires_grad标志状态
b = (a * a).sum() # 由a计算引入b
print(b.grad_fn) # b是一个操作的结果,所以它有一个grad_fn。
print(b.requires_grad) # 打印a的requires_grad标志状态
def test(x):
x = x*2
print(x.requires_grad) # False
return y
x = torch.randn(2, 2)
print(x.requires_grad) # False
y = test(x) # False
print(y.requires_grad) # True
False
True
<SumBackward0 object at 0x7f17b4c2d518>
True
False
False
True
梯度(Gradients)
让我们使用反向传播out.backward()
,它等同于out.backward(torch.Tensor(1)
)。
x = torch.ones(2, 2, requires_grad=True) # 创建一个张量并设置`requires_grad = True`来跟踪计算
y = x + 2 # 对x张量进行计算操作
z = y * y * 3 # 继续实现复杂的操作
out = z.mean() # 输出z的均值
out.backward() # 实现反向传播
print(x.grad) # 打印梯度 d(out)/dx
tensor([[ 4.5000, 4.5000],
[ 4.5000, 4.5000]])
4.5矩阵的计算过程如下所示:
我们还可以使用autograd做一些疯狂的事情!
x = torch.randn(3, requires_grad=True)
print(x)
y = x * 2
print(type(y))
print(type(y.data))
print(y.data.norm())
while y.data.norm() < 1000:
y = y * 2
print(y)
gradients = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(gradients) # 沿着某方向的梯度
print(x.grad)
tensor([-0.3905, 1.3533, 1.0339])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
tensor(3.4944)
tensor([ -399.9199, 1385.7303, 1058.7094])
tensor([ 102.4000, 1024.0000, 0.1024])
我们还可以通过使用torch.no_grad()
包装代码块来停止autograd
跟踪在张量上的历史记录,其中require_grad = True
:
print(x.requires_grad)
print((x ** 2).requires_grad)
with torch.no_grad():
print((x ** 2).requires_grad)
True
True
False
参考
Deep Learning with PyTorch: A 60 Minute Blitz(https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
- Idea 常用快捷键
- silverlight中如何方便在多个"场景"即Xaml文件之间随意切换?
- 电子签名实现的思路、困难及解决方案
- JavaScript排序算法详解
- 事件处理需小心
- Mysql读写分离方案-MySQL Proxy环境部署记录
- Mysql读写分离方案-Amoeba环境部署记录
- linux系统终端命令提示符设置(PS1)记录
- 从MapX到MapXtreme2004[10]-根据zoom值修改显示范围
- Linq to Sql中Single写法不当可能引起的数据库查询性能低下
- 获得定长字符串
- vue2.0知识点汇总
- ie6,ie7,ff 的css兼容hack写法
- 使用子查询时应当注意的
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- ESP8266和ESP32配置(需使用ROS1和ROS2)
- SpringBoot+Gradle+ MyBatisPlus3.x搭建企业级的后台分离框架
- frp 内网穿透远程桌面(Windows 10)配置
- 你来讲讲AQS是什么吧?都是怎么用的?
- Angular单元测试里pipe的mock设计
- 亿级数据判断 bitmap-布隆过滤器
- centOS8 安装MySQL8(亲测)
- 聊一聊微信小程序包内容
- 全面分析 MySQL并发控制
- Flink History Server
- 几种定时任务(Timer、TimerTask、ScheduledFuture)的退出—结合真实案例【JAVA并发】
- gitlab内存消耗大,频繁出现502错误的解决办法
- Java基于POI实现excel任意多级联动下拉列表——支持从数据库查询出多级数据后直接生成【附源码】
- Elasticsearch 通过Scroll遍历索引,构造pandas dataframe 【Python多进程实现】
- 【Java】 NullPointerException、ArrayIndexOutOfBoundsException、ClassCastException、ArrayIndexOutOfBoundsE