pytorch中的nn.Embedding
时间:2022-07-23
本文章向大家介绍pytorch中的nn.Embedding,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
直接看代码:
import torch
import torch.nn as nn
embedding=nn.Embedding(10,3)
input=torch.LongTensor([[1,2,4,5],[4,3,2,9]])
embedding(input)
tensor([[[ 0.8052, -0.1044, -0.6971],
[ 1.3792, -0.1265, -1.1444],
[ 1.4152, -0.1551, -1.2433],
[ 0.7060, -1.0585, 0.5130]],
[[ 1.4152, -0.1551, -1.2433],
[-0.9881, -0.1601, 0.6339],
[ 1.3792, -0.1265, -1.1444],
[-1.1703, 1.8496, 0.8113]]], grad_fn=<EmbeddingBackward>)
第一个参数是字的总数,第二个参数是字的向量表示的维度。
我们的输入input是两个句子,每个句子都是由四个字组成的,使用每个字的索引来表示,于是使用nn.Embedding对输入进行编码,每个字都会编码成长度为3的向量。
再看看下个例子:
embedding = nn.Embedding(10, 3, padding_idx=0)
input=torch.LongTensor([[0,2,0,5]])
embedding(input)
tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.0829, 1.4141, 0.0277],
[ 0.0000, 0.0000, 0.0000],
[ 0.1337, -1.1472, 0.2182]]], grad_fn=<EmbeddingBackward>)
transformer中的字的编码就可以这么表示:
class Embeddings(nn.Module):
def __init__(self,d_model,vocab):
#d_model=512, vocab=当前语言的词表大小
super(Embeddings,self).__init__()
self.lut=nn.Embedding(vocab,d_model)
# one-hot转词嵌入,这里有一个待训练的矩阵E,大小是vocab*d_model
self.d_model=d_model # 512
def forward(self,x):
# x ~ (batch.size, sequence.length, one-hot),
#one-hot大小=vocab,当前语言的词表大小
return self.lut(x)*math.sqrt(self.d_model)
# 得到的10*512词嵌入矩阵,主动乘以sqrt(512)=22.6,
#这里我做了一些对比,感觉这个乘以sqrt(512)没啥用… 求反驳。
#这里的输出的tensor大小类似于(batch.size, sequence.length, 512)
参考:
https://zhuanlan.zhihu.com/p/107889011
https://blog.csdn.net/qq_38883844/article/details/104331382
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法