PyTorch中Transformer模型的搭建
PyTorch
最近版本更新很快,1.2/1.3/1.4几乎是连着出,其中: 1.3/1.4版本主要是新增并完善了PyTorchMobile
移动端部署模块和模型量化模块。 而1.2版中一个重要的更新就是把加入了NLP领域中炙手可热的Transformer
模型,这里记录一下PyTorch
中Transformer
模型的用法(代码写于1.2版本,没有在1.3/1.4版本测试)。
1. 简介
也许是为了更方便地搭建Bert
,GPT-2
之类的NLP模型,PyTorch
将Transformer
相关的模型分为nn.TransformerEncoderLayer
、nn.TransformerDecoderLayer
、nn.LayerNorm
等几个部分。搭建模型的时候不一定都会用到, 比如fastai
中的Transformer
模型就只用到了encoder
部分,没有用到decoder
。
至于WordEmbedding
和PositionEncoding
两个部分需要自己另外实现。
WordEmbedding
可以直接使用PyTorch
自带的nn.Embedding
层。
PositionEncoding
层的花样就多了,不同的模型下面有不同的PositionEncoding
,比如Transformer的原始论文Attention is all you need
中使用的是无参数的PositionEncoding
, Bert
中使用的是带有学习参数的PositionEncoding
。
在本文中介绍的是参考Transformer
原始论文实现的Sequence2sequence
形式的Transformer
模型。
2. Sequence2sequence形式的Transformer模型搭建:
2.1 无可学习参数的PositionEncoding层
无参数的PositionEncoding
计算速度快,还可以减小整个模型的尺寸,据说在有些任务中,效果与有参数的接近。
class PositionalEncoding(nn.Module):
def __init__(self, d_model,dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
2.2 有可学习参数的PositionEncoding层
我曾在一个序列预测任务(非NLP)里面对比过两种PositionEncoding
层,发现带有参数的PositionEncoding
层效果明显比没有参数的PositionEncoding
要好。
带参数的PositionEncoding
层的定义更为简单,直接继承一个nn.Embedding
,再续上一个dropout
就可以了。因为nn.Embedding
中包含了一个可以按索引取向量的权重矩阵weight
。
class LearnedPositionEncoding(nn.Embedding):
def __init__(self,d_model, dropout = 0.1,max_len = 5000):
super().__init__(max_len, d_model)
self.dropout = nn.Dropout(p = dropout)
def forward(self, x):
weight = self.weight.data.unsqueeze(1)
x = x + weight[:x.size(0),:]
return self.dropout(x)
2.3 Sequence2sequence模型
将embedding、position_encoding、encoder和decoder拼接起来,就可以构成一个完整的sequence2sequence形式的Transformer模型了。
class S2sTransformer(nn.Module):
def __init__(self,vocab_size,position_enc,d_model = 512,nhead = 8,num_encoder_layers=6,
num_decoder_layers=6,dim_feedforward=2048,dropout=0.1):
super(S2sTransformer,self).__init__()
# Preprocess
self.embedding = nn.Embedding(vocab_size,d_model)
self.pos_encoder_src = position_enc(d_model=512)
# tgt
self.pos_encoder_tgt = position_enc(d_model=512)
# Encoder
encoder_layer = nn.TransformerEncoderLayer(d_model,nhead,dim_feedforward,dropout)
encoder_norm = nn.LayerNorm(d_model)
self.encoder = nn.TransformerEncoder(encoder_layer,num_encoder_layers,encoder_norm)
# Decoder
decoder_layer = nn.TransformerDecoderLayer(d_model,nhead,dim_feedforward,dropout)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = nn.TransformerDecoder(decoder_layer,num_decoder_layers,decoder_norm)
self.output_layer = nn.Linear(d_model,vocab_size)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def forward(self, src,tgt,src_mask = None,tgt_mask = None,
memory_mask = None,src_key_padding_mask = None,
tgt_key_padding_mask = None,memory_key_padding_mask = None):
# word embedding
src = self.embedding(src)
tgt = self.embedding(tgt)
# shape check
if src.size(1) != tgt.size(1):
raise RuntimeError("the batch number of src and tgt must be equal")
if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
raise RuntimeError("the feature number of src and tgt must be equal to d_model")
# position encoding
src = self.pos_encoder_src(src)
tgt = self.pos_encoder_tgt(tgt)
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
output = self.output_layer(output)
# return output
return softmax(output,dim = 2)
def generate_square_subsequent_mask(self, sz):
r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
def _reset_parameters(self):
r"""Initiate parameters in the transformer model."""
for p in self.parameters():
if p.dim() > 1:
xavier_uniform_(p)
模型搭建好了之后,就可以按照Sequence2sequence的训练方式进行训练了, 唯一需要注意的就是Transformer
的forward
过程是并行的,与基于RNN
的Sequence2sequence
模型稍有不同。
训练过程可以参考PyTorch
官网提供的chatbot
的教程
- 微信小程序调用接口返回数据或提交数据
- 巧用shell脚本生成快捷脚本(r2第12天)
- asp.net动态增加服务器端控件并提交表单
- c# asp.net 实现分页(pager)功能
- 一次数据库无法登陆的"问题"及排查(r2第11天)
- popcorn-js视频Video框架简单用法
- 一次数据库响应缓慢的问题排查(r2第9天)
- 通过Ajax方式上传文件(input file),使用FormData进行Ajax请求
- C# 读取指定文件夹下所有文件
- ASP.NET 实现Base64文件流下载PDF
- MVC自定义视图引擎地址
- JS禁止鼠标右键、禁止全选、复制、粘贴的方法(所谓的防盗功能)
- impdp异常中断导致的问题(r2第8天)
- 利用autocomplete.js实现仿搜索效果(ajax动态获取后端[C#]数据)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法