白话Xavier | 神经网络初始化的工程选择
论文的链接在这里:https://machinelearning.wustl.edu/mlpapers/paper_files/AISTATS2010_GlorotB10.pdf
PyTorch代码
在介绍论文和理论之前,先讲一下如何使用在PyTorch中使用Xavier初始化:
def _initialize_weights(self):
# print(self.modules())
for m in self.modules():
print(m)
if isinstance(m, nn.Linear):
# print(m.weight.data.type())
# input()
# m.weight.data.fill_(1.0)
init.xavier_uniform_(m.weight, gain=1)
print(m.weight)
通俗讲理论
论文提出的Xavier的主要思想:每一层输出的方差应该尽量相等。
前向传播
下面进行推导:每一层的权重应该满足什么条件才能实现这个目标。
目前我们需要用到下面方差相关的定理:假设有随机变量x和w,他们都能服从均值为0,方差为
的分布,两者独立同分布,那么:
就会服从均值为0,方差为
的分布
服从均值为0,方差为
对于上面的内容不了解的可以看下面的推导:
这里用激活函数tanh来做讲解(论文中是用tanh激活函数来讲解的),下图中左图是tanh的形状,右边是tanh的导数的形状:
从上图可以看到,当x处于0附近的时候,导数接近1,因此这就是tanh激活函数的线性区域,也就是在x=0的附近,
假设我们的所有输入数据满足均值为0,方差为
的分布;参数w满足均值为0,方差为
。假设第一层是卷积层,卷积层的参数为n:
于是我们可以得到经过卷积层输出的结果:
这里面忽略偏执b。
因为我们假设输入x和权重w相互独立,所以可以得到输出z的方差:
为了更好的展示,我们把网络层的层号写在变量的上标处:
因此我们也可以得到:
如果这是一个k层的网络(包括卷积层和全链接层),可以得到:
继续展开,可以得到:
【我们来消化一下】首先为什么卷积层也可以用w*x这样全链接的线性形式呢?其实想一想,x是输入数据,不管是全连接还是图像数据,都可以是服从均值为0,方差为
的分布。为什么均值为0呢? 因为一般图像数据中会进行normalization,将均值置为0.
那么,卷积层的n是通道数乘上卷积核的尺寸,全连接层的n就是一层的神经元的数量。
为什么这里不考虑激活函数呢?其实是考虑了,之前提到的tanh在均值为0的附近,是相当于线性函数
的,所以上面的推导忽略了激活函数的部分。
我们继续往下走,从上式可以看出,后面的连乘是非常危险的,假设
总是大于1,这意味着,随着层数的越深,数据的方差会越来越大;当然如果小于1的话,层数越深,数值的方差就会越来越小。
回到这个公式:
如果想要达成一开始说的目标:每一层的权重应该满足什么条件才能实现这个目标,那么也就是
,应该满足:
推广到任意层:
目前为止,介绍了前向传播的情况。也就是如果要让每一层的数据的方差相等,需要满足:
反向传播
反向传播的原理也是一样的。
假设我们还是k层的网络,然后第k层的梯度是:
这里想象一下,每一层的一个数据,会对下一层的n个数据有连接,连接的权重就是
,这就是视野域。因此第k层的每一个数据反向传播的时候,也会受到k+1层中在视野域内n个数据的梯度的影响,因此可以得到:
假设每一层的数据的梯度数据都服从均值为0,方差为
的分布的话,那么可以得到下面的公式:
【这里是
还是
】
因为这里说的是k-1层的一个数据到底能对k层的多少数据产生影响。我分析了之后发现卷积和全连接层其实有略微的差别。全连接层中,k-1层的一个数据与k层的所有数据都有连接,所以这里是
,如果是卷积层的话,k-1层的一个数据只与k层中的
个数据连接,所以应该使用
。但是原文论文中似乎仅仅考虑了全连接层,所以这里使用的还是
。如果还是不明白可以画一画全连接的图再思考一下,类似这样的图。
继续往下走,我们假设有k层网络,对上面的公式不断展开,可以得到:
同样的连乘,为了让每一层的数据的梯度的方差相等,需要满足:
也就是:
与前向传播的形式大致相同。
现在我们为了让正向传播的数据的方差相同和反向传播的数据梯度的方差相同,得到了下面两个公式:
为了均衡考虑,所以最终的权重方差应该满足:
论文中依然使用的是均匀分布进行初始化参数,我们假设权重均匀分布初始化的范围为[-a,a],那么这个均匀分布的方差就是:
关于平均分布的均值和方差的推导可以看下面的步骤:
所以呢,我们最后可以得到:
因此,xavier的初始化方法,就是把参数初始化成下面范围内的均匀分布:
参考文章:
- https://zhuanlan.zhihu.com/p/220280792
- https://zhuanlan.zhihu.com/p/22044472
- https://blog.csdn.net/shuzfan/article/details/51338178
- https://www.cnblogs.com/hejunlin1992/p/8723816.html
- https://blog.csdn.net/CHS007chs/article/details/78133563
- https://zhuanlan.zhihu.com/p/27919794
- python常见模块之os模块
- BZOJ 2127: happiness(最小割解决集合划分)
- lightswitch 添加 TreeView 控件
- P3227 [HNOI2013]切糕
- python常见模块之random模块
- P2756 飞行员配对方案问题
- P1151 子数整数
- python常见模块之time模块
- U10783 名字被和谐了
- BZOJ 1174: [Balkan2007]Toponyms
- 1355: [Baltic2009]Radio Transmission
- Equation Group(方程式组织)
- Python中下划线---完全解读
- python常见模块之collections模块
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- R语言小数定律的保险业应用:泊松分布模拟索赔次数
- R语言中自编基尼系数的CART回归决策树的实现
- ArrayList的删除姿势你都掌握了吗
- sas神经网络:构建人工神经网络模型来识别垃圾邮件
- R语言多分类logistic逻辑回归模型在混合分布模拟单个风险损失值评估的应用
- 10分钟带你入门git到github
- 微服务[学成在线] day18:基于oauth2实现RBAC认证授权、微服务间认证实现
- 【TBase开源版测评】分布式事务全局一致性
- R语言进阶之主成分分析
- 二胖写参数校验的坎坷之路
- 图像倾斜校正算法的MATLAB实现:图像倾斜角检测及校正
- R语言时间序列数据指数平滑法分析交互式动态可视化
- R语言进阶之图形的合并
- R语言广义线性模型索赔频率预测:过度分散、风险暴露数和树状图可视化
- 还在使用Future轮询获取结果吗?CompletionService快来了解下。