【动手学深度学习笔记】之通过丢弃法缓解过拟合问题
1. 通过丢弃法缓解过拟合问题
除了上一篇文章介绍的权重衰减法,深度学习常用的缓解过拟合问题的方法还有丢弃法。本文介绍倒置丢弃法及其实现。
1.1 丢弃法
丢弃法主要应用在含有隐藏层的模型中,所以我们以多层感知机为例,来说明丢弃法的实现方法。
当对这个多层感知机的隐藏层使用丢弃法时,该层的隐藏单元将有概率被丢弃(归零),有的概率会以做拉伸。丢弃概率是丢弃法的超参数。
由于对神经单元的丢弃是随机的,因此都有可能被清零,输出层的计算无法过度依赖中的任何一个。在训练过程中,丢弃法起到了正则化的作用,并可以用来缓解过拟合的问题。
但在测试过程中,一般不使用丢弃法。
1.2 实现丢弃法
1.2.1 原理实现丢弃法
对于一个单隐藏层,输入个数为4,隐藏层神经单元个数为5,激活函数为的多层感知机它的隐藏单元的计算表达式为
我们现在对这个隐藏层使用丢弃概率为的丢弃法。设随机变量为0和1的概率分别为和。得到新的隐藏单元计算表达式为
由于随机变量的期望,因此丢弃法对隐藏单元的输出期望没有影响。
1.2.2 程序实现丢弃法
def dropout(X,p):
X = X.float()
#断言判断输入的p是否满足大于0小于1
assert 0<=p<=1
kp = 1-p
#p=1,丢弃所有元素
if kp == 0:
return torch.zeros(X.shape)
#xi返回0或1。rand生成大于0小于1的随机数
xi = (torch.rand(X.shape)<kp).float()
return xi*X/kp
下面创建一个tensor测试一下dropout函数
X = torch.tensor([1,2,3,4])
dropout(X,0) #丢弃概率为0
dropout(X,0.5) #丢弃概率为0.5
dropout(X,0.9) #丢弃概率为0.9
测试得到的结果如下,证明我们的丢弃法函数实现了需要的效果
tensor([1., 2., 3., 4.])
tensor([2., 4., 0., 0.])
dropout(X,0.9)
1.3 应用丢弃法
应用丢弃法仅仅需要在搭建神经网络时,在应用丢弃法的层后面,调用Dropout层并指定丢弃概率即可。
下面以一个含有两个隐藏层的多层感知机为例,实现在神经网络中使用丢弃法。
1.3.1 定义和初始化模型
这里使用torch.nn模块中的Dropout函数,这个函数可以在训练时发挥作用,测试模型时,不发挥作用。
导入需要的库
import torch
from torch import nn
设置参数
num_inputs = 784
num_hiddens1 = 256
num_hiddens2 = 256
num_outputs = 10
p1,p2 = 0.5,0.5
#超参数,调节变量p1,p2可以看到丢弃法的效果
搭建神经网络
class FlattenLayer(nn.Moudle):
def __init__(self):
super(FlattenLayer,self).__init__()
def forward(self,x):
return x.view(x.shape[0],-1)
net = nn.Sequential(
FlattemLayer(),
nn.Linear(num_inputs,num_hiddens1),
nn.ReLU(),
nn.Dropout(p1),
nn.Linear(num_hiddens1,num_hiddens2),
nn.ReLU(),
nn.Dropout(p2),
nn.Linear(num_hiddens2,num_outputs)
)
初始化模型参数
for param in net.parameters():
nn.init.normal_(param,mean=0,std = 0.01)
看一下神经网络的结构
print(net)
1.3.2 损失函数、优化函数和读取数据
由于使用的本质还是softmax回归模型,因此使用softmax运算和交叉熵损失函数,这里直接使用PyTorch中的函数
loss = torch.nn.CrossEntropyLoss()
这里我们还是使用小批量随机梯度下降算法作为优化算法。
optimizer = torch.optoim.SGD(net.parameters(),lr = 0.5)
在开始训练模型之前,首先需要读取数据集(因为读取的还是Fashion-MNIST数据集,所以代码与之前的一样)。
batch_size = 256
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download=True,transform=transforms.ToTensor())
#获取训练集
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=False,download=True,transform=transforms.ToTensor())
#获取测试集
#生成迭代器
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle = True,num_workers = 0)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size = batch_size,shuffle=False,num_workers=0)
1.3.3 计算分类准确率
为了看出效果,需要计算出分类的准确率,在之前的推文写过了,这里不再赘述(详见softmax回归部分文章)。
def net_accurary(data_iter,net):
right_sum,n = 0.0,0
for X,y in data_iter:
#从迭代器data_iter中获取X和y
right_sum += (net(X).argmax(dim=1)==y).float().sum().item()
#计算准确判断的数量
n +=y.shape[0]
#通过shape[0]获取y的零维度(列)的元素数量
return right_sum/n
1.3.4 训练模型
def train(net,loss,optimizer,net_accuary,train_iter,test_iter,batch_size,num_epochs):
for epoch in range(num_epochs+1):
for X,y in train_iter:
l = loss(net(X),y).sum()
optimizer.zero_grad()
l.backward()
optimizer.step()
train_right = net_accuary(train_iter,net) #训练集准确率
test_right = net_accurary(test_iter, net) #测试集的准确率
print('第%d学习周期, 训练准确率%.3f, 测试准确率%.3f' % (epoch + 1, train_right, test_right))
1.4 丢弃法对训练效果的影响
由于Fashion-MNIST数据集较大,不容易产生过拟合现象,因此效果不是很明显。
- 设置对两个隐藏层的丢弃概率均为0时(不使用丢弃法),训练准确率和测试准确率如下
- 设置对两个隐藏层的丢弃概率都是0.5时,训练准确率和测试准确率如下
- 设置对两个隐藏层的丢弃概率都是0.2时,训练准确率和测试准确率如下
- 设置对两个隐藏层的丢弃概率分别为0.2和0.5时,训练准确率和测试准确率如下
- Eureka中RetryableClientQuarantineRefreshPercentage参数探秘
- Edgware.RC1中ZuulFallbackProvider的改进
- JPA的多表复杂查询:详细篇
- 尝试使用Memcached遇到的狗血问题
- Enumerable#Zip 实现一下
- 更新自己,不要影响其他人
- 【译】Spring官方教程:Spring Boot整合消息中间件RabbitMQ
- [实录]解决Migrator.Net 小bug
- Jenkins Pipeline插件十大最佳实践!
- Spring Cloud Hystrix的请求合并
- JQuery JCshare 0.1 分享插件
- Java中的即时编译(Just-in-time compilation)
- 无尽的忙碌换来幸福的日子
- 消费者驱动的微服务契约测试套件:Spring Cloud Contract
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法