对于小批量梯度下降以及如何配置批量大小的入门级介绍
随机梯度下降是训练深度学习模型的主要方法。
梯度下降有三种主要的的方法,具体使用哪一种要视情况而定。
在这篇文章中,你会了解一般情况下你该选择使用何种梯度下降,以及如何配置它。
读完这篇文章后,你会知道:
- 梯度下降是什么,它从高层次来看是怎样工作的。
- 批量,随机和小批量梯度下降分别是什么,每种方法的好处和局限性有哪些。
- 将小批量梯度下降作为指导方法,在您的应用程序上进行配置。
让我们开始吧。
对于小批量梯度下降以及如何配置批量大小的入门级介绍
照片来自Brian Smithson,保留相关权利。
教程概述
本教程分为3个部分; 他们是:
- 什么是梯度下降?
- 对比3种类型的梯度下降
- 如何配置小批量梯度下降
什么是梯度下降?
梯度下降是在寻找机器学习算法(例如人工神经网络和逻辑回归),的权重或系数时常用的优化算法。
它的工作原理是:让模型对训练数据进行预测并利用预测误差来修改模型,使得训练误差减小。
该算法的目标是找到使得模型在训练数据集上的误差最小的模型参数(例如系数或权重)。它将模型沿误差梯度或斜率下降的方向移动,直到最小误差值,与此同时更新模型。因此,将算法命名为“梯度下降”。
下面的伪代码总结了梯度下降算法:
model = initialization(...)
n_epochs = ...
train_data = ...
for i in n_epochs:
train_data = shuffle(train_data)
X, y = split(train_data)
predictions = predict(X, model)
error = calculate_error(y, predictions)
model = update_model(model, error)
欲了解更多信息,请参阅帖子:
- Gradient Descent For Machine Learning
- How to Implement Linear Regression with Stochastic Gradient Descent from Scratch with Python
对比3种类型的梯度下降
根据计算误差时使用的训练样本数量的不同,梯度下降表现为不同的形式,用来更新模型。
用于计算误差的模式数包括用于更新模型的梯度稳定程度。我们将看到梯度下降配置在计算效率和误差梯度的精确度上存在张力。
三种主要的梯度下降是批量,随机和小批量。
让我们仔细看看每种方式。
什么是随机梯度下降?
随机梯度下降(通常缩写为SGD)是梯度下降算法的变体,它根据训练数据集的每个例子计算误差并更新模型。
对每个训练样例更新模型意味着随机梯度下降通常被称为在线机器学习算法。
优点
- 频繁的即时更新使人可以深入的了解模型的性能和改进的速度。
- 这种梯度下降的变体可能是最容易理解和实现的,特别是对于初学者来说。
- 提高的模型更新频率可以加快对某些问题的学习。
- 噪声更新过程可以允许模型避免局部最小值(例如过早收敛)。
缺点
- 如此频繁地更新模型比其他梯度下降算法的计算代价更高,训练大型数据集时花费的时间显著增加。
- 频繁的更新可能会导致噪声梯度信号,这可能导致模型参数频繁波动,从而导致模型误差波动(在训练时期有更高的方差)。
- 噪声学习的过程中减小了误差梯度,也会使算法难以到达模型的最小误差。
什么是批量梯度下降?
批量梯度下降是梯度下降算法的一种变体,用于计算训练数据集中每个样例的误差,但仅在所有训练样例已经计算过后才更新模型。
我们把在整个训练数据集上运行的一次循环称为训练代。因此,通常说批量梯度下降在每代训练结束时进行模型更新。
优点
- 对模型更新较少意味着这种梯度下降的变体比随机梯度下降在计算上更加高效。
- 更新频率的降低带来了更稳定的误差梯度,并可能使得一些问题更稳定的收敛。
- 预测误差的计算和模型更新的分离使算法可以通过并行处理实现。
缺点
- 更稳定的误差梯度可能导致模型过早收敛到不太理想的一组参数。
- 训练结束时更新需要在所有训练样例中累积预测误差,引入了额外的复杂度。
- 通常,批量梯度下降实现的过程中,需要将整个训练数据集存在存储器中并且可供算法使用。
- 对于大型数据集,模型更新可能会变慢,进而使得训练速度可能会变得非常慢。
什么是小批量梯度下降?
小批量梯度下降是梯度下降算法的一种变体,它将训练数据集分成小批量,用于计算模型误差和更新模型系数。
实现过程中可以选择在小批量上对梯度进行求和,或者取梯度的平均值,这进一步降低了梯度的方差。
小批量梯度下降试图在随机梯度下降的稳健性和批梯度下降的效率之间寻求平衡。这是在深度学习领域中使用梯度下降时最常见的实现方式。
优点
- 模型更新频率高于批量梯度下降,允许更稳健的收敛,避免局部最小值。
- 分批更新比随机梯度下降的计算效率更高。
- 分批处理允许在存储器中只存储部分数据,算法的存储和实现都变得更高效。
缺点
- 小批量需要为学习算法配置额外的“小批量”超参数。
- 错误信息必须在批量梯度下降等小批量训练实例中累积。
如何配置小批量梯度下降
小批量梯度下降是大多数应用中梯度下降的推荐变体,特别是在深度学习中。
为了简洁起见,通常将小批量大小称为“批量大小”,它通常被调整到正在执行实现的计算体系结构的一个方面。例如两个符合GPU或CPU硬件(如32,64,128,256等)的内存要求的功率。
批量大小是学习过程中的一个滑块。
- 较小的值让学习过程在训练过程中迅速收敛,代价是会引入噪声。
- 较大的值给出一个缓慢收敛的学习过程,并精确估计误差梯度。
技巧1:32可能是一个好的批量大小的默认值。
... 【批量大小】通常选择在1到几百之间,例如:【批量大小】 = 32是一个很好的默认值,大于10的值发挥了矩阵-矩阵积对于矩阵-向量积提速优势。
技巧2:在调整批量大小时,查看不同批量大小的模型验证误差的学习曲线与训练时间是一个好主意。
...通过比较训练曲线(训练和验证误差与训练时间量),在其他超参数(除了学习率)被选定之后,可以与其他超参数分开进行优化。
技巧3:在调整完其他超参数后,调整批量大小和学习速率。
... 批量大小和学习率可能会与其他超参数稍有交互,所以两者都应该在最后重新优化。一旦选择了批量大小,一般可以固定,而其他超参数可以进一步优化(除了动量超参数)。
进一步阅读
如果您正在深入研究,本节将提供更多有关该主题的资源。
相关文章
附加阅读
- 随机梯度下降,维基百科
- 在线机器学习,维基百科
- 梯度下降优化算法的概述
- 深度架构梯度训练的实用建议,2012
- 随机优化的高效小批量训练,2014
- 在深度学习中,为什么我们不使用整个训练集来计算梯度?Quora
- 大型机器学习优化方法,2016
概要
在这篇文章中,你了解了梯度下降算法和你在实践中应该使用的变体。
具体来说,你了解到:
- 梯度下降是什么,它从高层次来看是怎样工作的。
- 批量,随机和小批量梯度下降分别是什么,每种方法的好处和局限性有哪些。
- 将小批量梯度下降作为指导方法,在您的应用程序上进行配置。
你有任何问题吗?
在下面的评论中提出您的问题,我会尽我所能来回答。
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- jQuery 介绍 以及基本使用
- 答应我,用了这个jupyter插件,别再重复造轮子了
- 商业数据分析从入门到入职(9)Python网络数据获取
- 谈一谈还原解包后小程序页面wxss样式的若干方法
- 什么?不使用selenium爬京东评论?你是不是在骗我
- 骚操作,Python操作PPT,你会吗?
- 用了这个jupyter插件,我已经半个月没打开过excel了
- Mística:一款支持任意协议的应用程序通信工具
- 为什么阿里巴巴禁止使用BigDecimal的equals方法做等值比较?
- 原创 | codefroces中的病毒,这题有很深的trick,你能解开吗?
- 原创 | git的远程分支是干啥的,和本地的有什么区别?
- 京东技术主导:全新架构的分布式事务Hmily 2.1.1发布
- iOS音视频接入-TRTC接入前期key、秘钥等准备
- 你一定不知道的 Linux 使用技巧
- 当 Python 爬虫搭配起 Bilibili 唧唧,奇怪的生产力出现了