如何找到最优学习率?
经过了大量炼丹的同学都知道,超参数是一个非常玄乎的东西,比如batch size,学习率等,这些东西的设定并没有什么规律和原因,论文中设定的超参数一般都是靠经验决定的。但是超参数往往又特别重要,比如学习率,如果设置了一个太大的学习率,那么loss就爆了,设置的学习率太小,需要等待的时间就特别长,那么我们是否有一个科学的办法来决定我们的初始学习率呢?
在这篇文章中,我会讲一种非常简单却有效的方法来确定合理的初始学习率。
学习率的重要性
目前深度学习使用的都是非常简单的一阶收敛算法,梯度下降法,不管有多少自适应的优化算法,本质上都是对梯度下降法的各种变形,所以初始学习率对深层网络的收敛起着决定性的作用,下面就是梯度下降法的公式
这里
就是学习率,如果学习率太小,会导致网络loss下降非常慢,如果学习率太大,那么参数更新的幅度就非常大,就会导致网络收敛到局部最优点,或者loss直接开始增加,如下图所示。
学习率的选择策略在网络的训练过程中是不断在变化的,在刚开始的时候,参数比较随机,所以我们应该选择相对较大的学习率,这样loss下降更快;当训练一段时间之后,参数的更新就应该有更小的幅度,所以学习率一般会做衰减,衰减的方式也非常多,比如到一定的步数将学习率乘上0.1,也有指数衰减等。
这里我们关心的一个问题是初始学习率如何确定,当然有很多办法,一个比较笨的方法就是从0.0001开始尝试,然后用0.001,每个量级的学习率都去跑一下网络,然后观察一下loss的情况,选择一个相对合理的学习率,但是这种方法太耗时间了,能不能有一个更简单有效的办法呢?
一个简单的办法
Leslie N. Smith 在2015年的一篇论文“Cyclical Learning Rates for Training Neural Networks”(http://t.cn/R6fTeFn)中的3.3节描述了一个非常棒的方法来找初始学习率,同时推荐大家去看看这篇论文,有一些非常启发性的学习率设置想法。
这个方法在论文中是用来估计网络允许的最小学习率和最大学习率,我们也可以用来找我们的最优初始学习率,方法非常简单。首先我们设置一个非常小的初始学习率,比如1e-5,然后在每个batch之后都更新网络,同时增加学习率,统计每个batch计算出的loss。最后我们可以描绘出学习的变化曲线和loss的变化曲线,从中就能够发现最好的学习率。
下面就是随着迭代次数的增加,学习率不断增加的曲线,以及不同的学习率对应的loss的曲线。
从上面的图片可以看到,随着学习率由小不断变大的过程,网络的loss也会从一个相对大的位置变到一个较小的位置,同时又会增大,这也就对应于我们说的学习率太小,loss下降太慢,学习率太大,loss有可能反而增大的情况。从上面的图中我们就能够找到一个相对合理的初始学习率,0.1。
之所以上面的方法可以work,因为小的学习率对参数更新的影响相对于大的学习率来讲是非常小的,比如第一次迭代的时候学习率是1e-5,参数进行了更新,然后进入第二次迭代,学习率变成了5e-5,参数又进行了更新,那么这一次参数的更新可以看作是在最原始的参数上进行的,而之后的学习率更大,参数的更新幅度相对于前面来讲会更大,所以都可以看作是在原始的参数上进行更新的。正是因为这个原因,学习率设置要从小变到大,而如果学习率设置反过来,从大变到小,那么loss曲线就完全没有意义了。
实现
上面已经说明了算法的思想,说白了其实是非常简单的,就是不断地迭代,每次迭代学习率都不同,同时记录下来所有的loss,绘制成曲线就可以了。下面就是使用PyTorch实现的代码,因为在网络的迭代过程中学习率会不断地变化,而PyTorch的optim里面并没有把learning rate的接口暴露出来,导致显示修改学习率非常麻烦,所以我重新写了一个更加高层的包mxtorch(http://t.cn/RYiHSuy),借鉴了gluon的一些优点,在定义层的时候暴露初始化方法,支持tensorboard,同时增加了大量的model zoo,包括inceptionresnetv2,resnext等等,提供预训练权重,model zoo参考于Cadene的repo(http://t.cn/RlWbF5k)。目前这个repo刚刚开始,欢迎有兴趣的小伙伴加入我。
下面就是部分代码,近期会把找学习率的代码合并到mxtorch中。这里使用的数据集是kaggle上的dog breed,使用预训练的resnet50,ScheduledOptim的源码点这里(http://t.cn/RYiHHKA)。
criterion = torch.nn.CrossEntropyLoss()net = model_zoo.resnet50(pretrained=True)net.fc = nn.Linear(2048, 120)with torch.cuda.device(0):
net = net.cuda()basic_optim = torch.optim.SGD(net.parameters(), lr=1e-5)optimizer = ScheduledOptim(basic_optim)lr_mult = (1 / 1e-5) ** (1 / 100)lr = []losses = []best_loss = 1e9for data, label in train_data:
with torch.cuda.device(0):
data = Variable(data.cuda())
label = Variable(label.cuda())
# forward
out = net(data)
loss = criterion(out, label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr.append(optimizer.learning_rate)
losses.append(loss.data[0])
optimizer.set_learning_rate(optimizer.learning_rate * lr_mult)
if loss.data[0] < best_loss:
best_loss = loss.data[0]
if loss.data[0] > 4 * best_loss or optimizer.learning_rate > 1.:
breakplt.figure()plt.xticks(np.log([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]), (1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1))plt.xlabel('learning rate')plt.ylabel('loss')plt.plot(np.log(lr), losses)plt.show()plt.figure()plt.xlabel('num iterations')plt.ylabel('learning rate')plt.plot(lr)
one more thing
通过上面的例子我们能够有一个非常有效的方法寻找初始学习率,同时在我们的认知中,学习率的策略都是不断地做decay,而上面的论文别出心裁,提出了一种循环变化学习率的思想,能够更快的达到最优解,非常具有启发性,推荐大家去阅读阅读。
(完)
- Oracle中的PUBLIC(r10笔记第14天)
- Data Guard高级玩法:通过闪回恢复switchover主库 (r10笔记第13天)
- WinForm/MIS项目开发之中按钮级权限实践
- 恢复控制文件避免使用resetlogs选项 (r10笔记第12天)
- Go实现短url项目
- 【Go 语言社区】GO语言多核并行化的问题
- mysql执行计划看是否最优
- 通过IP定位区域的SQL优化思路(r10笔记第10天)
- Java基础-day06-知识点回顾与练习
- 【Go 语言社区】Golang语言的多核并行化例子
- 一条SQL语句的执行计划变化探究(r10笔记第9天)
- 【Go 语言社区】Web 通信 之 长连接、长轮询(long polling)--转
- Dubbo入门-协议;注册中心
- Oracle 12c PDB浅析(二)(r8笔记第29天)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- R语言无监督学习:PCA主成分分析可视化
- 如何用r语言制作交互可视化报告图表
- R语言大数据分析纽约市的311万条投诉统计可视化与时间序列分析
- R语言动态可视化:制作历史全球平均温度的累积动态折线图动画gif视频图
- R语言里的非线性模型:多项式回归、局部样条、平滑样条、广义加性模型分析
- 使用R语言进行机制检测的隐马尔可夫模型HMM
- 【Kubernetes】Octant再探...
- 聊聊claudb的SlaveReplication
- 深度学习trick--labelsmooth
- Java锁的那些事儿
- React Hooks踩坑分享
- Python 自动化,Helium 凭什么取代 Selenium?
- Explain详解与索引最佳实践
- 使用SAP Analysis Path Framework (APF)展示CDS view数据
- 基于docker封装prometheus解决时区问题