PyTorch 60分钟入门系列之神经网络
神经网络
神经网络可以使用torch.nn
包来构建。
前面的学习大致了解了autograd
,nn
依赖于autograd
来定义模型并进行求导。一个nn.Module
包含多个神经网络层,以及一个forward(input)
方法来返回output
。
例如,看看以下这个分类数字图像的网络:
LeNet
它是一个简单的前馈网络。它将输入逐步地传递给多个层,然后给出输出。 一个典型的神经网络训练过程如下:
- 定义一个拥有可学习参数(或权重)的神经网络
- 在输入数据集上进行迭代
- 在网络中处理输入数据
- 计算损失(输出离分类正确有多大距离)
- 梯度反向传播给网络的参数
- 更新网络的权重,通常使用一个简单的更新规则(SGD):
weight = weight + learning_rate * gradient
定义网络结构
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
# nn.Module子类的函数必须在构造函数中执行父类的构造函数
# 等价于nn.Model.__init__(self)
super(Net, self).__init__()
# 一个图像输入通道(灰度图), 6 输出通道(6张FeatureMap), 5x5 卷积核
self.conv1 = nn.Conv2d(1, 6, 5)
# 定义卷积层:输入6张特征图,输出16张特征图,卷积核5x5
self.conv2 = nn.Conv2d(6, 16, 5)
# 定义全连接层:线性连接(y = Wx + b),16*5*5个节点连接到120个节点上
self.fc1 = nn.Linear(16 * 5 * 5, 120)
# 定义全连接层:线性连接(y = Wx + b),120个节点连接到84个节点上
self.fc2 = nn.Linear(120, 84)
# 定义全连接层:线性连接(y = Wx + b),84个节点连接到10个节点上
self.fc3 = nn.Linear(84, 10)
# 定义向前传播函数,并自动生成向后传播函数(autograd)
def forward(self, x):
# 输入x->conv1->relu->2x2窗口的最大池化->更新到x
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# 输入x->conv2->relu->2x2窗口的最大池化->更新到x
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
# view函数将张量x变形成一维向量形式,总特征数不变,为全连接层做准备
x = x.view(-1, self.num_flat_features(x))
# 输入x->fc1->relu,更新到x
x = F.relu(self.fc1(x))
# 输入x->fc2->relu,更新到x
x = F.relu(self.fc2(x))
# 输入x->fc3,更新到x
x = self.fc3(x)
return x
def num_flat_features(self, x):
# 除了批处理维度之外的所有维度。
size = x.size()[1:]
num_features = 1
for s in size:
num_features *= s
return num_features
net = Net()
print(net)
Net(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
在实现过程中只需要定义forward
函数,backward
函数(用来计算梯度)是使用autograd
自动定义的。并且可以在forward
中使用任意的Tensor
运算操作。
模型中可学习的参数是通过net.parameters()
返回的:
params = list(net.parameters())
print(len(params))
print(params[0].size()) # conv1's .weight
10
torch.Size([6, 1, 5, 5])
让我们尝试一个随机的32x32
输入!
注意:这个网络(LeNet)的预期输入大小是32x32
。要在MNIST
数据集上使用此网络,请将数据集中的图像调整为32x32
。
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
tensor([[-0.0819, 0.1214, 0.0144, -0.0429, 0.0046, 0.0520, -0.0673,
0.0878, -0.1724, -0.1151]])
将梯度缓冲区中所有的参数置0,并使用随机的梯度进行反向传播:
net.zero_grad()
out.backward(torch.randn(1, 10))
torch.nn
仅支持mini-batch
。整个的torch.nn
包仅支持小批量的数据,而不是一个单独的样本。例如,nn.Conv2d
应传入一个4D
的Tensor
,维度为(nSamples X nChannels X Height X Width
)。如果你有一个单独的样本,使用input.unsqueeze(0)
来添加一个伪批维度。
回顾:
-
torch.Tensor
一个支持autograd
操作(如backward()
)的多维数组 -
nn.Module
神经网络模块。封装参数的便捷方式,帮助者将它们移动到GPU,导出,加载等。 -
nn.Parameter
一种Tensor
,当给Module
赋值时自动注册一个参数。 -
autograd.Function
实现一个autograd
操作的forward
和backward
定义。每一个Tensor
操作,创建至少一个Function
节点,来连接那些创建Tensor
的函数,并且记录其历史。
在这里,我们涵盖了:
- 定义神经网络
- 处理输入并调用
backward
定义损失函数
一个损失函数以一个(output
, target
)对为输入,然后计算一个值用以估计输出结果离目标结果多远。
在nn的包里存在定义了多种损失函数。一个简单的损失函数:nn.MSELoss
它计算输出与目标的均方误差。
output = net(input)
target = torch.arange(1, 11) # 一个虚拟的目标
target = target.view(1, -1) # 使其形状与输出相同。
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)
tensor(38.9289)
现在,如果使用其.grad_fn
属性反向追踪损失,您将看到一个如下所示的计算图:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> view -> linear -> relu -> linear -> relu -> linear -> MSELoss -> loss
因此,当我们调用loss.backward()
时,损失对应的整个图都被求导,并且图中所有的Tensor
都会带有累积了梯度的.grad
属性requres_grad=True
。
print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU
<MseLossBackward object at 0x000002AE1A0953C8>
<AddmmBackward object at 0x000002AE1A0954A8>
<ExpandBackward object at 0x000002AE1A0953C8>
反向传播
要进行反向传播,我们只需要调用loss.backward()
。注意:需要清除现有的梯度,否则梯度将累积到现有梯度。
现在我们将调用loss.backward()
,并看看conv1
偏置在backward
之前和之后的梯度变化。
net.zero_grad() # 清除现有的梯度
print('conv1.bias.grad before backward')
print(net.conv1.bias.grad) # 打印之前的梯度值
loss.backward()
print('conv1.bias.grad after backward')
print(net.conv1.bias.grad) # 打印反向传播之后的梯度值
conv1.bias.grad before backward
tensor([ 0., 0., 0., 0., 0., 0.])
conv1.bias.grad after backward
tensor([ 0.0383, 0.1029, 0.0044, 0.1332, 0.0659, -0.0402])
权值更新
实践中最简单的更新规则是随机梯度下降(SGD):
weight = weight - learning_rate * gradient
learning_rate = 0.01
for f in net.parameters():
f.data.sub_(f.grad.data * learning_rate)
然而,当使用神经网络时,希望使用各种不同的更新规则,例如SGD
,Nesterov-SGD
,Adam
,RMSProp
等等。为了实现这一点,Pytorch构建一个优化包:torch.optim
,来实现所有的方法。使用非常简单:
import torch.optim as optim
# 创建优化器
optimizer = optim.SGD(net.parameters(), lr=0.01)
# 在训练的循环迭代中使用
optimizer.zero_grad() # 清除现有的梯度
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step() # 更新值
注意:在观察梯度变化时,首先要通过
optimizer.zero_grad()
清除现有的梯度,否则梯度将累积到现有梯度。
参考
Deep Learning with PyTorch: A 60 Minute Blitz(https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
- 移动商城第七篇【购物车增删改查、提交订单】
- Shiro入门这篇就够了【Shiro的基础知识、回顾URL拦截】
- OFTest(一):如何忽略一些字段在端口poll报文
- Shiro第二篇【授权、整合Spirng、过滤器】
- Ajax数据的爬取(淘女郎为例)
- 在IDEA中编写Spark的WordCount程序
- Shiro第三篇【授权过滤器、与ehcache整合、验证码、记住我】
- Spark核心RDD、什么是RDD、RDD的属性、创建RDD、RDD的依赖以及缓存、
- Caused by: java.net.ConnectException: Connection refused: master/192.168.3.129:7077
- java.util.zip.ZipException: invalid LOC header (bad signature)
- 递归就这么简单
- Activiti就是这么简单
- WebService就是这么简单
- eclipse中hadoop2.3.0环境部署及在eclipse中直接提交mapreduce任务
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法