PyTorch: Attention Mechanism
Date: 2019-11-16
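1. Paper: Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
Encoder
At each time step one word is fed in, and the hidden state is updated as $h_t = f(h_{t-1}, x_t)$. The activation function $f$ can be sigmoid, tanh, ReLU, softplus, and so on.
After every word in the sequence has been read, the final hidden state is mapped to a fixed-length vector $c = \tanh(V h_N)$.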
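The encoder recurrence above can be sketched in a few lines. This is only a minimal illustration under stated assumptions: the sizes and the parameter names `W_x`, `W_h` are hypothetical, and $f$ is taken to be a plain tanh update, whereas the paper itself uses a gated hidden unit. The point it shows is that sequences of different lengths are summarized into a vector $c$ of the same size.

```python
import torch

torch.manual_seed(0)

n_input, n_hidden, n_summary = 4, 3, 2   # illustrative sizes (assumptions)

# Hypothetical parameters: W_x and W_h define f; V maps h_N to the summary c.
W_x = torch.randn(n_hidden, n_input)
W_h = torch.randn(n_hidden, n_hidden)
V = torch.randn(n_summary, n_hidden)


def encode(xs):
    """Apply h_t = f(h_{t-1}, x_t) over the sequence, then return c = tanh(V h_N)."""
    h = torch.zeros(n_hidden)              # h_0
    for x in xs:                           # one word vector per time step
        h = torch.tanh(W_x @ x + W_h @ h)  # f chosen as tanh here (an assumption)
    return torch.tanh(V @ h)               # fixed-length summary vector c


short_seq = torch.randn(3, n_input)   # a 3-word sequence
long_seq = torch.randn(7, n_input)    # a 7-word sequence
c_short, c_long = encode(short_seq), encode(long_seq)
print(c_short.shape, c_long.shape)
```

Both calls return a vector of length `n_summary`, regardless of how many words were read.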
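Decoder
As the architecture diagram shows, the hidden state at time $t$ is determined by $h_{t-1}$, $y_{t-1}$, and $c$: $h_t = f(h_{t-1}, y_{t-1}, c)$, with $h_0 = \tanh(V' c)$.
The output $y_t$ is in turn determined by $h_t$, $y_{t-1}$, and $c$:
$P(y_t \mid y_{t-1}, y_{t-2}, \ldots, y_1, c) = g(h_t, y_{t-1}, c)$
Here $f$ and $g$ are both activation functions, and $g$ is typically softmax.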
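A single decoder step following these equations might look like the sketch below. The parameter names `W_f`, `W_g`, `V_prime` and all sizes are hypothetical, $f$ is assumed to be a tanh over the concatenated inputs, and $g$ is a softmax over a linear map. The paper's actual decoder is more elaborate, so treat this as a simplified reading of the formulas rather than a reproduction.

```python
import torch

torch.manual_seed(0)

n_vocab, n_hidden, n_summary = 5, 3, 2   # illustrative sizes (assumptions)

# Hypothetical parameters for h_0, f (state update) and g (output distribution).
V_prime = torch.randn(n_hidden, n_summary)
W_f = torch.randn(n_hidden, n_hidden + n_vocab + n_summary)
W_g = torch.randn(n_vocab, n_hidden + n_vocab + n_summary)


def decoder_step(h_prev, y_prev, c):
    """h_t = f(h_{t-1}, y_{t-1}, c), then P(y_t | ...) = g(h_t, y_{t-1}, c)."""
    h_t = torch.tanh(W_f @ torch.cat([h_prev, y_prev, c]))            # f as tanh
    probs = torch.softmax(W_g @ torch.cat([h_t, y_prev, c]), dim=0)   # g as softmax
    return h_t, probs


c = torch.randn(n_summary)      # stands in for the encoder's summary vector
h = torch.tanh(V_prime @ c)     # h_0 = tanh(V' c)
y = torch.zeros(n_vocab)        # one-hot placeholder for the start symbol
y[0] = 1.0

for _ in range(4):              # greedy decoding for a few steps
    h, probs = decoder_step(h, y, c)
    y = torch.zeros(n_vocab)
    y[probs.argmax()] = 1.0     # feed the most likely symbol back in
```

Note that $c$ is passed to every step unchanged, which is exactly the fixed-length bottleneck that attention later relaxes.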
Below is my PyTorch implementation of this original version of the seq2seq model:
import numpy as np
import torch
import torch.nn as nn

# S: symbol that marks the start of the decoder input
# E: symbol that marks the end of the decoder output
# P: padding symbol used when a word is shorter than n_step

char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
num_dic = {n: i for i, n in enumerate(char_arr)}

seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]

# Seq2Seq parameters
n_step = 5                   # maximum word length
n_hidden = 128
n_class = len(num_dic)       # 29
batch_size = len(seq_data)   # 6


def make_batch(seq_data):
    input_batch, output_batch, target_batch = [], [], []

    for src, tgt in seq_data:
        # pad local copies so the caller's data is not modified
        src = src + 'P' * (n_step - len(src))
        tgt = tgt + 'P' * (n_step - len(tgt))

        enc_ids = [num_dic[n] for n in src]
        dec_ids = [num_dic[n] for n in ('S' + tgt)]
        target = [num_dic[n] for n in (tgt + 'E')]

        input_batch.append(np.eye(n_class)[enc_ids])
        output_batch.append(np.eye(n_class)[dec_ids])
        target_batch.append(target)  # class indices, not one-hot

    # make tensors
    return (torch.tensor(np.array(input_batch), dtype=torch.float32),
            torch.tensor(np.array(output_batch), dtype=torch.float32),
            torch.tensor(target_batch, dtype=torch.long))


# Model
class Seq2Seq(nn.Module):
    def __init__(self):
        super().__init__()
        # dropout is omitted: nn.RNN applies it only between stacked layers,
        # so it has no effect with num_layers=1
        self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden)
        self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden)
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, enc_input, enc_hidden, dec_input):
        enc_input = enc_input.transpose(0, 1)  # [n_step, batch_size, n_class]
        dec_input = dec_input.transpose(0, 1)  # [n_step+1, batch_size, n_class]

        # enc_states: [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
        _, enc_states = self.enc_cell(enc_input, enc_hidden)
        # outputs: [n_step+1(=6), batch_size, num_directions(=1) * n_hidden(=128)]
        outputs, _ = self.dec_cell(dec_input, enc_states)

        return self.fc(outputs)  # [n_step+1(=6), batch_size, n_class]


input_batch, output_batch, target_batch = make_batch(seq_data)

model = Seq2Seq()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5000):
    # initial hidden state: [num_layers * num_directions, batch_size, n_hidden]
    hidden = torch.zeros(1, batch_size, n_hidden)

    # input_batch : [batch_size, n_step, n_class]
    # output_batch: [batch_size, n_step+1 (because of 'S'), n_class]
    # target_batch: [batch_size, n_step+1 (because of 'E')], not one-hot
    output = model(input_batch, hidden, output_batch)  # [n_step+1, batch_size, n_class]
    output = output.transpose(0, 1)  # [batch_size, n_step+1(=6), n_class]
    loss = 0
    for i in range(len(target_batch)):
        # output[i]: [n_step+1, n_class], target_batch[i]: [n_step+1]
        loss += criterion(output[i], target_batch[i])
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# Test
def translate(word):
    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])

    # initial hidden state: [num_layers * num_directions, batch_size(=1), n_hidden]
    hidden = torch.zeros(1, 1, n_hidden)
    with torch.no_grad():
        output = model(input_batch, hidden, output_batch)  # [n_step+1(=6), 1, n_class]

    predict = output.argmax(dim=2).squeeze(1).tolist()  # most likely class per step
    decoded = [char_arr[i] for i in predict]
    # stop at the first 'E'; keep everything if the model never emits one
    end = decoded.index('E') if 'E' in decoded else len(decoded)
    translated = ''.join(decoded[:end])

    return translated.replace('P', '')


print('test')
print('man ->', translate('man'))
print('mans ->', translate('mans'))
print('king ->', translate('king'))
print('black ->', translate('black'))
print('upp ->', translate('upp'))
Later, the attention mechanism was proposed on top of the seq2seq model.
Paper: Neural Machine Translation by Jointly Learning to Align and Translate
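For orientation, the additive scoring used in that paper, $e_{ij} = v^\top \tanh(W s_{i-1} + U h_j)$ followed by a softmax over source positions, could be sketched as the module below. This is a simplified sketch and not the article author's code: the class name `AdditiveAttention` is hypothetical, and the encoder and decoder states are assumed to share one size, whereas the paper uses a bidirectional encoder.

```python
import torch
import torch.nn as nn


class AdditiveAttention(nn.Module):
    """Scores every encoder state against the previous decoder state."""

    def __init__(self, n_hidden):
        super().__init__()
        self.W = nn.Linear(n_hidden, n_hidden, bias=False)  # applied to the decoder state
        self.U = nn.Linear(n_hidden, n_hidden, bias=False)  # applied to the encoder states
        self.v = nn.Linear(n_hidden, 1, bias=False)

    def forward(self, dec_state, enc_outputs):
        # dec_state: [batch, n_hidden]; enc_outputs: [batch, src_len, n_hidden]
        energy = torch.tanh(self.W(dec_state).unsqueeze(1) + self.U(enc_outputs))
        scores = self.v(energy).squeeze(-1)                  # [batch, src_len]
        weights = torch.softmax(scores, dim=1)               # attention weights per source position
        # weighted sum of encoder states: a per-step context vector
        context = torch.bmm(weights.unsqueeze(1), enc_outputs).squeeze(1)  # [batch, n_hidden]
        return context, weights


torch.manual_seed(0)
attn = AdditiveAttention(n_hidden=8)
context, weights = attn(torch.randn(2, 8), torch.randn(2, 5, 8))
print(context.shape, weights.shape)
```

Unlike the single vector $c$ in the model above, `context` is recomputed at each decoding step from all encoder states.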
Original article: https://www.cnblogs.com/dhName/p/11872118.html