ICLR 2020 Mogrifier LSTM 解析
1. 简介
LSTM模型作为一种经典的RNN网络结构,常用于NLP任务当中。在本篇工作中,我们进一步拓展了原始LSTM模型。注意到原始LSTM中输入x和之前状态h_prev是完全独立的,可能导致上下文信息的流失。我们提出一种形变LSTM,将输入x和之前状态h_prev进行交互,再输入进各个门里面运算。最后实验表明,改进过后的Mogrifier LSTM在各项任务均优于传统LSTM
2. 回顾传统LSTM
LSTM模型结构如下所示
LSTM模型结构
它一共有4个门控系统,分别是遗忘门,输入门,候选记忆细胞,输出门
各个门的计算公式如下
- 遗忘门:
- 输入门:
- 候选记忆细胞:
- 输出门:
- 记忆细胞:
- 新一轮的隐藏状态:
其中
代表的是sigmoid运算
各个门作用及机理如下
- 遗忘门:主要控制是否遗忘上一层的记忆细胞状态, 输入分别是当前时间步序列数据,上一时间步的隐藏状态,进行矩阵相乘,经sigmoid激活后,获得一个值域在[0, 1]的输出F,再跟上一层记忆细胞进行对应元素相乘,输出F中越接近0,代表需要遗忘上层记忆细胞的元素。
- 候选记忆细胞:这里的区别在于将sigmoid函数换成tanh激活函数,因此输出的值域在[-1, 1]。
- 输入门:与遗忘门类似,也是经过sigmoid激活后,获得一个值域在[0, 1]的输出。它用于控制当前输入X经过候选记忆细胞如何流入当前时间步的记忆细胞。 如果输入门输出接近为0,而遗忘门接近为1,则当前记忆细胞一直保存过去状态
- 输出门:也是通过sigmoid激活,获得一个值域在[0,1]的输出。主要控制记忆细胞到下一时间步隐藏状态的信息流动
相较于传统的RNN,LSTM引入了门机制,记忆细胞的设计使其能保存一定信息,在时间步进行传递,更好地捕捉时间序列较长的信息。而遗忘门的设计,更是能判断上一时刻信息是否对当前时刻产生影响,进而优化梯度流在整个网路的传递
但我们可以注意到,作为各个门的输入,X和隐藏状态H是完全独立的,这也是该研究的动机,如果输入前我能让X和隐藏状态H做交互,那性能是否能得到提升?
3. Mogrifier LSTM
Mogrifier LSTM引入以下两个公式
公式1
公式2
为了分别交互X 和 H,作者额外设置了两个矩阵Q,R
并且设定了一个超参数i,该参数分别控制X和H应该如何进行交互计算
当
,整个模型就退化成原始的LSTM
最后乘以一个常数2,这是因为经过sigmoid运算后,其值分布在(0, 1),这样反复乘下去,值是会越来越小的。因此乘以一个2保证其数值的稳定性。
4. 实验
我们来简单看下实验结果
实验结果
经过简单的改进,Mogrifier LSTM在各数据集上的表现均好于传统的LSTM
此外作者还探索了Mogrify中的超参数
设置,对模型性能的影响
超参数i对模型性能的影响
文中也对Mogrify这种交互方式给了相应的示意图
Mogrify交互方式示意图
5. 代码解析
作者也开源了相关代码在github上:https://github.com/RMichaelSwan/MogrifierLSTM
class MogLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz, mog_iteration):
super(MogLSTM, self).__init__()
self.input_size = input_sz
self.hidden_size = hidden_sz
self.mog_iterations = mog_iteration
# 这里hiddensz乘4,是将四个门的张量运算都合并到一个矩阵当中,后续再通过张量分块给每个门
self.Wih = Parameter(torch.Tensor(input_sz, hidden_sz*4))
self.Whh = Parameter(torch.Tensor(hidden_sz, hidden_sz*4))
self.bih = Parameter(torch.Tensor(hidden_sz*4))
self.bhh = Parameter(torch.Tensor(hidden_sz*4))
# Mogrifiers
self.Q = Parameter(torch.Tensor(hidden_sz, input_sz))
self.R = Parameter(torch.Tensor(input_sz, hidden_sz))
self.init_weights()
def init_weights(self):
"""
权重初始化,对于W,Q,R使用xavier
对于偏置b则使用0初始化
:return:
"""
for p in self.parameters():
if p.data.ndimension() >= 2:
nn.init.xavier_uniform_(p.data)
else:
nn.init.zeros_(p.data)
def mogrify(self, xt, ht):
"""
计算mogrify
:param xt:
:param ht:
:return:
"""
for i in range(1, self.mog_iterations+1):
if(i % 2 == 0):
ht = (2*torch.sigmoid(xt @ self.R)*ht)
else:
xt = (2*torch.sigmoid(ht @ self.Q)*xt)
return xt, ht
def forward(self, x:torch.Tensor, init_states:Optional[Tuple[torch.Tensor, torch.Tensor]]=None) ->
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
batch_sz, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
ht = torch.zeros((batch_sz, self.hidden_size)).to(x.device)
Ct = torch.zeros((batch_sz, self.hidden_size)).to(x.device)
else:
ht, Ct = init_states
for t in range(seq_sz):
xt = x[:, t, :]
xt, ht = self.mogrify(xt, ht)
gates = (xt @ self.Wih + self.bih) + (ht @ self.Whh + self.bhh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) # chunk方法将tensor分块
# LSTM
ft = torch.sigmoid(forgetgate)
it = torch.sigmoid(ingate)
Ct_candidate = torch.tanh(cellgate)
ot = torch.sigmoid(outgate)
# outputs
Ct = (ft*Ct) + (it*Ct_candidate)
ht = ot * torch.tanh(Ct)
hidden_seq.append(ht.unsqueeze(Dim.batch)) # unsqueeze是给指定位置加上维数为1的维度
hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)
hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
return hidden_seq, (ht, Ct)
- 首先输入分别表示 输入维度,隐层维度,Mogrify的计算次数(也就是前面提到的超参数i)
- 然后分别初始化 权重Wih, Whh,Bih,Bhh。注意这里要乘4,这是因为LSTM里面有4个门,它将其合并为一个矩阵运算,最后再分配给4个门,提高速度
- 同样也是随机初始化用于Mogrify计算的两个矩阵Q ,R
- init_weights是进行参数初始化
- 方法mogrify里面,就是mogrify计算的部分了,根据计算次数设定一个for循环,根据奇偶性,分别对X和H进行交互计算,并返回
- 在forward前向计算中,先对隐层和输入X做初始化。然后进行矩阵运算,通过chunks方法将张量分成4部分,分别给四个门,再根据我们的前面的公式,分别进行sigmoid和tanh计算。然后更新细胞状态和隐藏状态。将隐藏状态连结成序列,最终返回隐藏状态序列,隐藏状态和细胞状态
6. 总结
本文的动机还是比较朴素的,从现有的LSTM缺陷出发,创新的引出mogrify计算方式,将原本互相独立的X和H进行了交互运算。并通过实验探讨了超参数的设置,最后的实验也表明,改造过后的Mogrifier LSTM相较于传统LSTM有着不小的提升。
- go语言读取csv文件并输出的方法
- HDUOJ----3342Legal or Not
- go语言基本类型
- HDUOJ----2647Reward
- hduoj------确定比赛名次
- HDUOJ----1165Eddy's research II
- HDUOJ-----1556Color the ball
- HDUOJ-----2175取(m堆)石子游戏
- Golang语言社区-Go语言递归
- go语言mongdb管道使用(一)
- HDUOJ---------2255奔小康赚大钱
- HDUOJ------1711Number Sequence
- HDUOJ---1712 ACboy needs your help
- HDUOJ---1867 A + B for you again
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Linux下如何查杀stopped进程详解
- Linux文件服务器实战详解(匿名用户)
- ubuntu16.0.4 设置固定ip地址的方法
- Linux文件服务器实战详解(虚拟用户)
- Linux CentOS下安装Tomcat9及web项目的部署
- Linux文件服务器实战详解(系统用户)
- 关于bash函数你可能不知道的一些事情(译)
- Linux Centos7系统端口占用问题的解决方法
- Linux中利用sudo进行赋权的方法详解
- Centos7下用户登录失败N次后锁定用户禁止登陆的方法
- Linux服务器被黑以后的详细处理步骤
- linux下用户程序同内核通信详解(netlink机制)
- yum安装本地rpm软件方案详解
- CentOS 部署 flask项目的方法
- 在linux服务器下使用版本控制软件SVN的方法