Highway Networks
导读
本文讨论了深层神经网络训练困难的原因以及如何使用Highway Networks去解决深层神经网络训练的困难,并且在pytorch上实现了Highway Networks。
一 、Highway Networks 与 Deep Networks 的关系
深层神经网络相比于浅层神经网络具有更好的效果,在很多方面都已经取得了很好的效果,特别是在图像处理方面已经取得了很大的突破,然而,伴随着深度的增加,深层神经网络存在的问题也就越大,像大家所熟知的梯度消失问题,这也就造成了训练深层神经网络困难的难题。2015年由Rupesh Kumar Srivastava等人受到LSTM门机制的启发提出的网络结构(Highway Networks)很好的解决了训练深层神经网络的难题,Highway Networks 允许信息高速无阻碍的通过深层神经网络的各层,这样有效的减缓了梯度的问题,使深层神经网络不在仅仅具有浅层神经网络的效果。
二、Deep Networks 梯度消失/爆炸(vanishing and exploding gradient)问题
我们先来看一下简单的深层神经网络(仅仅几个隐层)
先把各个层的公式写出来
我们对W1求导:
W = W - lr * g(t)
以上公式仅仅是四个隐层的情况,当隐层的数量达到数十层甚至是数百层的情况下,一层一层的反向传播回去,当权值 < 1的时候,反向传播到某一层之后权值近乎不变,相当于输入x的映射,例如,g(t) =〖0.9〗^100已经是很小很小了,这就造成了只有前面几层能够正常的反向传播,后面的那些隐层仅仅相当于输入x的权重的映射,权重不进行更新。反过来,当权值 > 1的时候,会造成梯度爆炸,同样是仅仅前面的几层能更改正常学习,后面的隐层会变得很大。
三、Highway Networks Formula
Notation
(.) 操作代表的是矩阵按位相乘
sigmoid函数:
Highway Networks formula
对于我们普通的神经网络,用非线性激活函数H将输入的x转换成y,公式1忽略了bias。但是,H不仅仅局限于激活函数,也采用其他的形式,像convolutional和recurrent。
对于Highway Networks神经网络,增加了两个非线性转换层,一个是 T(transform gate) 和一个是 C(carry gate),通俗来讲,T表示输入信息经过convolutional或者是recurrent的信息被转换的部分,C表示的是原始输入信息x保留的部分 ,其中 T=sigmoid(wx + b)
为了计算方便,这里定义了 C = 1 - T
需要注意的是x,y, H, T的维度必须一致,要想保证其维度一致,可以采用sub-sampling
或者zero-padding
策略,也可以使用普通的线性层改变维度,使其一致。
几个公式相比,公式3要比公式1灵活的多,可以考虑一下特殊的情况,T= 0的时候,y = x,原始输入信息全部保留,不做任何的改变,T = 1的时候,Y = H,原始信息全部转换,不在保留原始信息,仅仅相当于一个普通的神经网络。
四、Highway BiLSTM Networks
Highway BiLSTM Networks Structure Diagram
下图是 Highway BiLSTM Networks 结构图: input:代表输入的词向量 B:在本任务代表bidirection lstm,代表公式(2)中的 H T:代表公式(2)中的 T,是Highway Networks中的transform gate C:代表公式(2)中的 C,是Highway Networks中的carry gate Layer = n,代表Highway Networks中的第n层 Highway:框出来的代表一层Highway Networks 在这个结构图中,Highway Networks第 n - 1 层的输出作为第n层的输入
Highway BiLSTM Networks Demo
pytorch搭建神经网络一般需要继承nn.Module
这个类,然后实现里面的forward()
函数,搭建Highway BiLSTM Networks写了两个类,并使用nn.ModuleList
将两个类联系起来:
在HBiLSTM
类的forward()
函数里面我们实现Highway BiLSTM Networks
的的公式
首先我们先来计算H,上文已经说过,H可以是卷积或者是LSTM,在这里,normal_fc
就是我们需要的H
上文提及,x,y,H,T的维度必须保持一致,并且提供了两种策略,这里我们使用一个普通的Linear
去转换维度
也可以采用zero-padding
的策略保证维度一致
维度一致之后我们就可以根据我们的公式来写代码了:
最后的information_flow
就是我们的输出,但是,还需要经过转换维度保证维度一致。
更多的请参考Github: Highway Networks implement in pytorch
[https://github.com/bamtercelboo/pytorch_Highway_Networks]
五、Highway BiLSTM Networks 实验结果
本次实验任务是使用Highway BiLSTM Networks 完成情感分类任务(一句话的态度分成积极或者是消极),数据来源于Twitter情感分类数据集,以下是数据集中的各个标签的句子个数:
下图是本次实验任务在2-class数据集中的测试结果。图中1-300在Highway BiLSTM Networks中表示Layer = 1,BiLSTM 隐层的维度是300维。
实验结果:从图中可以看出,简单的多层双向LSTM并没有带来情感分析性能的提升,尤其是是到了10层之后,效果有不如随机的猜测。当用上Highway Networks之后,虽然性能也在逐步的下降,但是下降的幅度有了明显的改善。
References
- Highway Networks(paper) https://arxiv.org/pdf/1505.00387.pdf
- Training Very Deep Networks
https://arxiv.org/pdf/1507.06228.pdf
- 为什么深层神经网络难以训练 http://blog.csdn.net/binchasing/article/details/50300069
- Training Very Deep Networks–Highway Networks http://blog.csdn.net/cv_family_z/article/details/50349436
- Very Deep Learning with Highway Networks http://people.idsia.ch/~rupesh/very_deep_learning/
- Hightway Networks学习笔记 http://blog.csdn.net/sinat_35218236/article/details/73826203?utm_source=itdadao&utm_medium=referral
- Go实现短url项目
- 【Go 语言社区】GO语言多核并行化的问题
- mysql执行计划看是否最优
- 通过IP定位区域的SQL优化思路(r10笔记第10天)
- Java基础-day06-知识点回顾与练习
- 【Go 语言社区】Golang语言的多核并行化例子
- 一条SQL语句的执行计划变化探究(r10笔记第9天)
- 【Go 语言社区】Web 通信 之 长连接、长轮询(long polling)--转
- Dubbo入门-协议;注册中心
- Oracle 12c PDB浅析(二)(r8笔记第29天)
- 【Go 语言社区】在 Go 语言中,如何正确的使用并发
- Data Guard高级玩法:通过闪回恢复failover备库 (r10笔记第7天)
- ajax跨域问题-web开发必会
- 在线重定义的补充测试(r10笔记第26天)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- python实现线性回归之lasso回归
- 分页查询 offset 和 limit 和 limit 的区别
- mybatis文件映射之获取参数值时#和$的区别
- python实现线性回归之岭回归
- 操作系统实验之存储管理
- MySQL EXPLAIN 的使用
- mybatis文件映射之关联查询初探(一)
- python实现线性回归之弹性网回归
- 【原创】python倒排索引之查找包含某主题或单词的文件
- python实现逻辑回归
- Linux文件管理参考
- CloudBase Framework丨第一个 Deno 部署工具是如何打造的?
- 关于null通过+" ",String.ValueOf转换为字符串的问题!!!
- Java实现尺取法
- 【自然语言处理】利用朴素贝叶斯进行新闻分类(自己处理数据)