PyTorch笔记--关于backward()
PyTorch会根据计算过程来自动生成动态图,然后可以根据动态图的创建过程进行反向传播,计算得到每个结点的梯度值。
为了能够记录张量的梯度,在创建张量的时候可以设置参数requires_grad = True,或者创建张量后调用requires_grad_()方法。
>>> x = torch.randn(2, 2, requires_grad = True)
>>> x = torch.randn(3, 3)
>>> x.requires_grad_()
同时由它计算得到的中间张量也会自动被设置成requires_grad = True,下面的程序中y = x2,y.requires_grad = True
>>> x = torch.randn(3, 3) >>> x.requires_grad_() tensor([[ 0.6734, -2.4904, 2.0093], [-0.2601, -0.3734, -1.5601], [ 0.9121, 0.3902, 1.0404]], requires_grad=True) >>> y = x.pow(2) # y = x*x >>> y.requires_grad True
这里可以使用反向传播来计算x的梯度值。需要注意的是y是一个张量,不是标量,并不能直接使用backward方法(pytorch不允许张量对张量求导)。
可以采取求和的方式将y变成一个标量。代码如下
>>> y = x.pow(2).sum() >>> y.backward() >>> x.grad tensor([[ 1.3468, -4.9807, 4.0187], [-0.5203, -0.7469, -3.1202], [ 1.8243, 0.7804, 2.0808]])
我的理解如下(可能是错的)
假设(这里使用1维的张量举例,没有完全和代码中的对应)
其中
就是直接执行下面一段程序的结果
>>> y = x.pow(2) # y = x*x >>> y tensor([[0.4534, 6.2019, 4.0375], [0.0677, 0.1394, 2.4339], [0.8320, 0.1523, 1.0824]], grad_fn=<PowBackward0>)
是对x的每一项单独求平方,最后得到的y是与x同样shape的张量。
如果对其求和,表达式就变成了
对每一项的平方求和后,显然y就是一个标量了。
>>> y = x.pow(2).sum() >>> y # 输出y的值 tensor(15.4006, grad_fn=<SumBackward0>)
重点在于上面提到的这个表达式
对xj求偏导,将求和公式展开可以知道,其他的项对求导的结果是不影响的(因为只有yj是xj的表达式),
所以使用标量y对x的每一项求导,结果依然是正确的。同时求和后y成为了一个标量,也符合了pytorch在语法规则上的要求。
所以是相当于我们在不影响求导结果的前提下,对原来的表达式做了适当的变换,使其计算结果成为了一个标量,再使用这个
变换了的表达式对张量里的每一项进行求导。
同时,pytorch也提供了另一种方法辅助我们能够完成张量对张量的求导。就是使用grad_tensors参数。
torch.autograd.backward( tensors, # 要计算导数的张量 torch.autograd.backward(tensor)和tensor.backward()作用是等价的 grad_tensors=None, # 在用非标量进行求导时需要使用该参数 retain_graph=None, # 保留计算图 create_graph=False, grad_variables=None)
使用grad_tensors参数:
- 它是非标量(y)进行求导时才使用
- 它的大小需要与张量x(y=f(x))的大小相同
- 它在每一个元素是全1是就是正常的求导
- 可以调整它的值来针对每一项在求导时占据的权重。
下面是使用该参数进行求导的示例。在文章最后给的第一个链接中对其有详细的解释,这里做一个整理。
>>> x.grad.zero_() # 梯度清零 tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]) >>> grad_tensors = torch.ones_like(x) # 生成和x相同shape的全为1的张量 >>> y = x.pow(2) # y = x*x >>> y.backward(grad_tensors) # 将grad_tensors作为backward方法的输入 >>> x.grad # 张量x的梯度 tensor([[ 1.3468, -4.9807, 4.0187], [-0.5203, -0.7469, -3.1202], [ 1.8243, 0.7804, 2.0808]])
其中
>>> x.grad.zero_() # 梯度清零
是将原来计算的梯度清零,因为张量绑定的梯度张量在不清空的情况下会逐渐累积。
上面给出的代码可以看到,当我们建立了一个与x相同的shape且元素值都是1的grad_tensors张量后,并将其作为backward()方法的输入,
最终的求导结果也是正确的,它同样实现了张量对张量的求导。
>>> grad_tensors tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]])
grad_tensors是如何起作用的呢?
它是将对y的求导转换为对y·grad_tensors的求导。点乘后的结果是一个标量,再用这个新的表达式进行求导。
所以当grad_tensors的值都是1时,本质上还是对每一项进行求和。
grad_tensors张量的值相当于对对应位置的项求导前加了个系数。示例如下
>>> grad_tensors=torch.tensor([[1.0, 1.0, 1.0], [0.1, 0.1, 0.1], [0.01, 0.01, 0.01]]) >>> y = x.pow(2) >>> x.grad.zero_() tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]) >>> y.backward(grad_tensors) >>> x.grad tensor([[ 1.3468, -4.9807, 4.0187], [-0.0520, -0.0747, -0.3120], [ 0.0182, 0.0078, 0.0208]])
下面是grad_tensors的取值和最后的结果的对应:
>>> grad_tensors tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>> x.grad tensor([[ 1.3468, -4.9807, 4.0187], [-0.5203, -0.7469, -3.1202], [ 1.8243, 0.7804, 2.0808]])
>>> grad_tensors tensor([[1.0000, 1.0000, 1.0000], [0.1000, 0.1000, 0.1000], [0.0100, 0.0100, 0.0100]]) >>> x.grad # 可以简单地理解为每一项乘了一个grad_tensors相同位置的值 tensor([[ 1.3468, -4.9807, 4.0187], [-0.0520, -0.0747, -0.3120], [ 0.0182, 0.0078, 0.0208]])
参考:
https://blog.csdn.net/qq_27825451/article/details/89393332
https://book.51cto.com/art/202103/650997.htm
https://www.cnblogs.com/marsggbo/p/11549631.html
《深入浅出PyTorch》张校捷
原文地址:https://www.cnblogs.com/xxmrecord/p/15130853.html
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 微信小程序开发实战(14):音频组件(audio)
- dotnet OpenXML 文本 BodyProperties 的属性作用
- 微信小程序开发实战(15):视频组件(video)
- LeetCode 91,点赞和反对五五开,这题是好是坏由你来评判
- 打破国外垄断,开发中国人自己的编程语言(1):编写解析表达式的计算器
- 在CentOS8上编译安装开源EDA工具——Surelog
- 直播带货小程序源码中,商品详情页是如何获取html图片的
- LeetCode 90 | 经典递归问题,求出所有不重复的子集II
- 万字长文|Swift语法全面解析|附示例
- sshd服务搭建与管理
- Airflow Dag可视化管理编辑工具Airflow Console
- 使用 Clientset 获取 Kubernetes 资源对象
- Python爬虫 - 解决动态网页信息抓取问题
- Java内存故障?只是因为你不够帅!
- 线程池的execute方法和submit方法有什么区别?