Transformer中的维度变换
input: batch_size * max_sen_len
vocab_matrix dim: vocab_size * embedding_dim
PE(pos,2i)=sin(pos/10000^(2i/embedding_dim))
PE(pos,2i+1)=cos(pos/10000^(2i/embedding_dim))
encoder input embedding x = input token emb + position emb : batch_size * max_sen_len * embedding_dim
对每一句话(句尾</s>):[ max_sen_len * embedding_dim ]
ENCODER:
input -> dropout ->
(multihead SAN -> dropout -> residual connection -> LN -> FFN -> dropout -> RS connection-> LN) * 6 ->
[batch_size, max_sen_len, embedding_dim]
---- multihead self atten ----
WQ,WK,WV: embedding_dim * embedding_dim,
其中WQ,WK,WV可以切分为多头WQ_i,Wk_i,WV_i,即第二个维度 = embedding_dim/num_heads=d_k
WQ_i,Wk_i,WV_i: embedding_dim * d_k
q_i,k_i,v_i = x * WQ_i,WK_i,WV_i : max_sen_len * d_k
multihead:
q_i * k_i / srqt(d_k) : max_sen_len * max_sen_len
softmax之前要对q和k做mask,把pad 0的维度置为-inf,这样softmax之后对应位置权重为0
softmax(q_i * k_i / srqt(d_k) + Mask) * v_i = head_i, 在最后一个维度上做softmax
head_i: max_sen_len * d_k
Multi_head = cat num_heads of head_i = [head_1,head_2,...,head_8]: max_sen_len * embedding_dim
---- add & norm ----
max_sen_len * embedding_dim
----ffn & add & norm ----
ffn = Relu(W_1 * x + b_1) * W_2 +b_2
Relu = max(0,x)
W_1 : embedding_dim * ffn_hidden_size
b_1 : 1 * ffn_hidden_size
W_2 : ffn_hidden_size * embedding_dim
b_2 : 1 * embedding_dim
---- enc out ----
[batch_size, max_sen_len, embedding_dim]
DECODER:
decoder input -> droput ->
(masked multihead self atten -> dropout -> RS connection-> LN ->
multihead self atten -> dropout -> RS connection-> LN ->
FFN -> dropout -> RS connection-> LN) *6 ->
[batch_size, max_sen_len, vocab_size]
decoder input embedding y = input token emb + position emb : batch_size * max_sen_len, embedding_dim
对每一句话y(要添加起始符号<s>) : [ max_sen_len * embedding_dim ]
ENCODER的输出给每一层DECODER
---- masked multihead self atten ----
上三角矩阵置为-inf
q,k 来自encoder输出:max_sen_len, embedding_dim
q_i,k_i,v_i = y * WQ_i,WK_i,WV_i : max_sen_len * d_k
multihead:
q_i * k_i / srqt(d_k) : max_sen_len * max_sen_len
softmax(q_i * k_i / srqt(d_k) + Mask) * v_i = head_i : max_sen_len * d_k
Multi_head = cat num_heads of head_i = [head_1,head_2,...,head_8]: max_sen_len * embedding_dim
---- multihead self att ----
维度变换同上
Multi_head : max_sen_len * embedding_dim
---- add & norm ----
max_sen_len * embedding_dim
----ffn & add & norm ----
ffn = Relu(W_1 * y + b_1) * W_2 +b_2
Relu = max(0,y)
W_1 : embedding_dim * ffn_hidden_size
b_1 : 1 * ffn_hidden_size
W_2 : ffn_hidden_size * vocab_size
b_2 : 1 * vocab_size
[batch_size, max_sen_len, vocab_size]
---- dec out ----
[batch_size, max_sen_len, vocab_size]
在最后一维做softmax:vocab_size,得到词典哭上的概率分布,输出最大的概率,与真实标签进行交叉熵损失的计算,汇总一句话中每个的损失,优化,训练
原文地址:https://www.cnblogs.com/yh-blog/p/15115253.html
- Effective Modern C++翻译(7)-条款6:当auto推导出意外的类型时,使用显式的类型初始化语义
- 2.3 ls命令
- Effective Modern C++翻译(6)-条款5:auto比显示的类型声明要更好
- 大白话-prototype属性
- Effective Modern C++翻译(5)-条款4:了解如何观察推导出的类型
- Effective Modern C++翻译(4)-条款3:了解decltype
- 大白话-constructor
- Effective Modern C++翻译(3)-条款2:明白auto类型推导
- React Native在Android平台运行gif的解决方法
- Effective Modern C++翻译(2)-条款1:明白模板类型推导
- Android ormLite复杂条件查询
- Effective Modern C++翻译(1):序言
- C++操作mysql方法总结(2)
- Linux基础(day3)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 小书MybatisPlus第3篇-自定义SQL
- Nginx + Spring Boot 实现负载均衡
- 小书MybatisPlus第2篇-条件构造器的应用及总结
- 一个案例演示 Spring Security 中粒度超细的权限控制!
- 信息收集之主机发现:nmap
- 文本文件逐行处理–用java8 Stream流的方式
- 使用java8API遍历过滤文件目录及子目录及隐藏文件
- 使用位运算、值交换等方式反转java字符串-共四种方法
- 精讲RestTemplate第2篇-多种底层HTTP客户端类库的切换
- 精讲RestTemplate第1篇-在Spring或非Spring环境下如何使用
- 在图中添加多边形
- 设置坐标轴刻度的位置和样式
- OkHttp透明压缩,收获性能10倍,外加故障一枚
- 体验spring-boot-devtools热部署,流畅且不失强大
- 【每周一库】 simsearch - a simple and lightweight fuzzy search engine