PyTorch7:torch.nn.init
1. torch.nn.init 概述
因为神经网络的训练过程其实是寻找最优解的过程,所以神经元的初始值非常重要。
如果初始值恰好在最优解附近,神经网络的训练会非常简单。
而当神经网络的层数增加以后,一个突出的问题就是梯度消失和梯度爆炸。
梯度消失指的是由于梯度接近 0,导致神经元无法进行更新;
梯度爆炸指的是误差梯度在更新中累积得到一个非常大的梯度,这样的梯度会大幅度更新网络参数,进而导致网络不稳定。
torch.nn.init
模块提供了合理初始化初始值的方法。
一共提供了四类初始化方法:
- Xavier 分布初始化;
- Kaiming 分布初始化;
- 均匀分布、正态分布、常数分布初始化;
- 其它初始化。
有梯度边界的激活函数如 sigmoid
、tanh
和 softmax
等被称为饱和函数;
没有梯度边界的激活函数如 relu
被称为不饱和函数,它们对应的初始化方法不同。
2. 梯度消失和梯度爆炸
假设我们有一个 3 层的全连接网络:
对倒数第二层神经元的权重进行反向传播的公式为:
而 H3=H2*W3,所以
即 H2 ,即上一层的神经元的输出值,决定了W3的大小。
如果H2太大或太小,即梯度消失或梯度爆炸,将导致神经网络无法训练。
对于 sigmoid
和 tanh
等梯度绝对值小于 1 的激活函数来说,神经元的值会越来越小;
对于其它情况,假设我们构建了一个 100 层的全连接网络:
class MLP(nn.Module):
def __init__(self, neural_num, layers):
super(MLP, self).__init__()
self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for _ in range(layers)])
self.neural_num = neural_num
def forward(self, x):
for (i, linear) in enumerate(self.linears):
x = linear(x)
return x
def init(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight.data)
layers=100
neural_num=256
batch_size=16
net = MLP(neural_num, layers)
net.init()
inputs = torch.randn(batch_size, neural_num)
output = net(inputs)
打印一下神经网络的输出:
>>> print(output)
tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], grad_fn=<MmBackward>)
可以看到,神经元的值都变成了 nan
。这是为什么呢?
因为方差可以表征数据的离散程度,让我们来打印一下每次神经元的值的方差:
layers: 0, std: 15.7603178024292
layers: 1, std: 253.5698699951172
layers: 2, std: 4018.8212890625
layers: 3, std: 64962.9453125
layers: 4, std: 1050192.125
layers: 5, std: 16682177.0
...
layers: 28, std: 8.295319341711625e+34
layers: 29, std: 1.2787049888311946e+36
layers: 30, std: 2.0164275976565801e+37
layers: 31, std: nan
output is nan at 31th layers
tensor([[ 1.3354e+38, -2.0165e+38, -3.2402e+37, ..., 1.0439e+37,
-inf, 1.2574e+38],
[ -inf, -inf, inf, ..., -inf,
-inf, inf],
[ 1.2230e+37, -inf, 5.6356e+37, ..., -1.2776e+38,
inf, -inf],
...,
[ 2.1591e+37, 2.5838e+38, -2.9146e+38, ..., inf,
-inf, -inf],
[ inf, 1.9056e+38, -inf, ..., inf,
-inf, -inf],
[ -inf, inf, -1.7735e+38, ..., 4.8110e+37,
inf, -inf]], grad_fn=<MmBackward>)
可以看到,到第 30 层的时候,神经元的值已经非常大或非常小,终于在第 31 层的时候,神经元的值突破了存储精度的极限,只好变成了 nan
。
我们知道,一组数的方差 D和期望 E在 X与 Y相互独立的条件下满足下面的性质:
所以有:
当 E(X)=0,E(Y)=0 的时候:
在神经网络中,由于全连接层的性质
得
因为Xi服从一个方差为 1 的正态分布,而Wi也服从一个方差为 1 的分布,所以D(H11)的值就是神经元的个数,因此标准差就是根号n 。而全连接的性质决定了第 k层的神经元的标准差为n的k次方再开根号 ,与上面例子中 256 个神经元的情况基本吻合。
为了让神经网络的神经元值稳定,我们希望将每一层神经元的方差维持在 1,这样每一次前向传播后的方差仍然是 1,使模型保持稳定。这被称为“方差一致性准则”。
因为
为了让 D(Hi)=1,我们只需要让
即
我们验证一下:
class MLP(nn.Module):
def __init__(self, neural_num, layers):
super(MLP, self).__init__()
self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for _ in range(layers)])
self.neural_num = neural_num
def forward(self, x):
for (i, linear) in enumerate(self.linears):
x = linear(x)
print(f'layers: {i}, std: {x.std()}')
if torch.isnan(x.std()):
print(f'output is nan at {i}th layers')
break
return x
def init(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num))
layers=100
neural_num=256
batch_size=16
net = MLP(neural_num, layers)
net.init()
inputs = torch.randn(batch_size, neural_num)
output = net(inputs)
打印一下神经网络的神经元值:
layers: 0, std: 0.9983504414558411
layers: 1, std: 0.9868919253349304
layers: 2, std: 0.9728540778160095
layers: 3, std: 0.9823500514030457
layers: 4, std: 0.9672497510910034
layers: 5, std: 0.9902626276016235
...
layers: 95, std: 1.0507267713546753
layers: 96, std: 1.0782362222671509
layers: 97, std: 1.1384222507476807
layers: 98, std: 1.1450780630111694
layers: 99, std: 1.138461709022522
tensor([[-0.6622, 0.4439, 0.5704, ..., -2.2066, -1.1012, 0.0450],
[-0.1037, -0.3485, -0.0313, ..., -0.1562, -0.0520, 0.6481],
[ 0.3136, -0.0966, -1.5647, ..., -0.8760, -0.7498, 0.6339],
...,
[-0.6644, -0.4354, 0.8103, ..., 1.1510, 0.7699, 0.0607],
[-0.7511, -0.1086, 0.4008, ..., 1.5456, 0.6027, -0.0303],
[-0.5602, -0.1664, -0.9711, ..., -1.0884, -0.7040, 0.7415]],
grad_fn=<MmBackward>)
神经元的值果然是稳定的。
3. torch.nn.init.calculate_gain
这个函数计算激活函数之前和之后的方差的比例变化。比如 经过 rlue
以后还是 1,所以它的增益是 1。
PyTorch 给了常见的激活函数的变化增益:
激活函数 |
变化增益 |
---|---|
Linearity |
1 |
ConvND |
1 |
Sigmoid |
1 |
Tanh |
|
ReLU |
|
Leaky ReLU |
这个函数的参数如下:torch.nn.init.calculate_gain(nonlinearity, param=None)
-
nonlinearity
:激活函数; -
param
激活函数的参数。
4. Xavier initialization
为了解决饱和激活函数里的权重初始化问题,2010 年 Glorot 和 Bengio 发表了《Understanding the difficulty of training deep feedforward neural networks》论文,正式提出了 Xavier 初始化。Xavier 初始化通常使用均匀分布。由论文得,初始化后的张量中的值采样自 且
均匀分布下的 Xavier 初始化函数为 torch.nn.init.xavier_uniform_(tensor, gain=1)
。
Xavier 初始化也可以采用正态分布的方式,函数为 torch.nn.init.xavier_normal_(tensor, gain=1.0)
。其初始化后的张量中的值采样自 且
5. Kaiming initialization
2011 年 ReLU 函数横空出世,Xavier 初始化对 ReLU 函数不再适用。
2015 年,Kaiming He 提出了另一种初始化方法来适应 ReLU:
a
是 ReLU 上 时的斜率。同样的,Kaiming 初始化也有均匀分布和正态分布两种:
-
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
:均匀分布的 Kaiming 初始化函数; -
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
:正态分布的 Kaiming 初始化函数。
6. 其它初始化方法
-
torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
:初始化服从[a, b]
范围的均匀分布; -
torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
:初始化服从mean=0
,std=1
时的正态分布; -
torch.nn.init.constant_(tensor, val)
:初始化为任一常数; -
torch.nn.init.ones_(tensor)
:初始化为 1; -
torch.nn.init.zeros_(tensor)
初始化为 0; -
torch.nn.init.eye_(tensor)
:初始化对角线为 1,其它为 0; -
torch.nn.init.orthogonal_(tensor, gain=1)
:对张量的矩形区域进行初始化。由于张量都是矩形,个人理解是这个函数会将整个张量进行初始化。 -
torch.nn.init.sparse_(tensor, sparsity, std=0.01)
:以sparsity
为概率将张量填充 0,剩余的元素的标准差为std
。
- b这样去设计 URL,可以提高网站的访问量
- 程序员必知之SEO
- 进程监控工具supervisor 启动Mongodb
- 祭奠那些年,我弃坑的开源轮子
- 这些奇技浮巧,助你优化前端应用性能
- Stepping.js——两步完成前后端分离架构设计
- 我的职业是前端工程师【十】客户端存储艺术:数据存储与模型
- 【开源】2md:将复制的内容、网页转成 markdown
- React Native 持续部署实践— push 代码构建出新版的 Growth
- 技巧 - 如何好一个 Git 提交信息及几种不同的规范
- React、Vue、Ember 及其他前端开发者,请暂缓更新到 Chrome 59 浏览器
- 微软开源全新的文档生成工具DocFX
- 使用 MimeKit 和 MailKit 发送邮件
- 使用 React Native 重写大型 Ionic 应用后,我们想分享一下这八个经验
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 计算机基础知识总结与操作系统 PDF 下载
- Kafka工作流程及文件存储机制
- JS破解初探,折腾到头秃的美拍视频采集下载
- 去哪儿景点信息爬取并使用Django框架网页展示
- Kubernetes v1.15.3 升级到 v1.18.5 心得
- 结巴分词seo应用,Python jieba库基本用法及案例参考
- nali一个可以查询IP归属和CDN的命令
- 图片采集,python多线程采集头像图片源码附exe程序及资源包
- Python json数据爬取处理,红点官网大奖设计作品爬取
- 斗图狂魔必备沙雕表情包,python多线程爬取斗图啦表情图片
- 5个基本Linux命令行工具的现代化替代品
- Chrome 84 正式发布,支持私有方法、用户空闲检测!
- 类及数据库的应用,G-MARK网站数据Python爬虫系统的构建
- 获取素材图无忧,Pixabay图库网Python多线程采集下载
- Python关键词数据采集案例,5118查询网站关键词数据采集