深入Pytorch微分传参
时间:2019-10-18
本文章向大家介绍深入Pytorch微分传参,主要包括深入Pytorch微分传参使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
导数
这段代码揭示了多个变量的微分以及如何求解loss为向量的导数
m1 = Variable(torch.ones((3,2)), requires_grad=True)
m2 = Variable(torch.ones((3,2))*2, requires_grad=True)
m3 = Variable(torch.ones((3,2))*4, requires_grad=True)
x1 = m1*m2
x2 = x1 *m3
y = x1 + x2
gradients= torch.ones((3,2))
y.backward(gradients)
print(f"m1 grad:{m1.grad}, \n m2 grad:{m2.grad}, \n m3 grad:{m3.grad}, \n x1 grad:{x1.grad}, \n x2 grad:{x2.grad}, \n y grad:{y.grad}")
深入导数--hook机制
hook机制的详细解释
这段代码解释了导数是如何自动计算保存的,
import torch
from torch.autograd import Variable
def register_hook(self, hook):
r"""Registers a backward hook.
The hook will be called every time a gradient with respect to the
Tensor is computed. The hook should have the following signature::
hook(grad) -> Tensor or None
The hook should not modify its argument, but it can optionally return
a new gradient which will be used in place of :attr:`grad`.
This function returns a handle with a method ``handle.remove()``
that removes the hook from the module.
Example::
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad
2
4
6
[torch.FloatTensor of size (3,)]
>>> h.remove() # removes the hook
"""
if not self.requires_grad:
raise RuntimeError("cannot register a hook on a tensor that "
"doesn't require gradient")
if self._backward_hooks is None:
self._backward_hooks = OrderedDict()
if self.grad_fn is not None:
self.grad_fn._register_hook_dict(self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
return handle
v = Variable(torch.Tensor([2, 2, 2]), requires_grad=True)
h = v.register_hook(lambda grad: grad * grad) # double the gradient
v.backward(torch.Tensor([1, 1, 2]))
#先计算原始梯度,再进hook,获得一个新梯度。
print(v.grad.data)
# print(v.data)
# v.grad.data=torch.Tensor([0, 0, 0]) 梯度不置0就会根据hook自动累加
v.backward(torch.Tensor([1, 1, 1]))
print(v.grad.data)
# print(v.data)
h.remove() # removes the hook
使用with torch.no_grad()
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print('epoch %d, loss %f' % (epoch + 1, train_l.mean().numpy()))
SGD
这段代码揭示了一个最简单运用梯度下降的模型
import torch
from torch.autograd import Variable
from torch.distributions import normal
NUMBER = 100
# X = normal.Normal(loc = 0, scale = 1).sample((1, NUMBER))
X = torch.ones((1, NUMBER))*NUMBER
X= Variable(X, requires_grad=False)
b = torch.ones(X.shape[0])
b.requires_grad=True
epoch = 200
for i in range(epoch):
loss = torch.sum((b-X) ** 2)
b.grad = Variable(torch.zeros(X.shape[0])) #梯度置0
loss.backward()
b.data = b.data- b.grad * (1/NUMBER)/10
if not i%10:
print(f" {i} b is: {b}, b.grad is: {b.grad}")
原文地址:https://www.cnblogs.com/icodeworld/p/11700818.html
- C# 通过IEnumberable接口和IEnumerator接口实现自定义集合类型foreach功能
- 微信小程序之picker组件
- 详述 IntelliJ IDEA 中恢复代码的方法「进阶篇」
- mac环境下mongodb的安装和使用
- C# 终极基类Object介绍
- EF基础知识小记五(一对多、多对多处理)
- 字符串的方法汇总
- Kotlin和anko融合进行Android开发
- EF基础知识小记六(使用Code First建模自引用关系,常用于系统菜单、文件目录等有层级之分的实体)
- Vue.js简介
- Kotlin之Elvis 操作符
- C# 文件操作系列一
- Kotlin 包和 import 语句使用
- Redis-ML简介(第5部分)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 杭电的题,输出格式卡的很严。HDU 1716 排列2
- 移动端的(-webkit-linear-gradient -webkit-radial-gradient)
- ACM一年记,总结报告(希望自己可以走得很远)
- 移动端顺序问题上
- [USACO1.5]回文质数 Prime Palindromes
- 移动端上上(transform-translateZ注册)
- [USACO1.3]虫洞wormhole
- HTML--HTML入门篇(我想10分钟入门HTML,可以,交给我吧)
- 移动端初级知识点解析:translateZ translateY rotateY(上上上)
- new String() split详解
- XML--XML从入门到精通 Part 1 认识XML
- css的linear-gradient注意点
- css的linear-gradient
- 第十届山东省赛L题Median(floyd传递闭包)+ poj1975 (昨晚的课程总结错了,什么就出度出度,那应该是叫讨论一个元素与其余的关系)
- css中border-radius