TensorFlow-10-基于 LSTM 建立一个语言模型
今日资料: https://www.tensorflow.org/tutorials/recurrent 中文版: http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/recurrent.html 代码: https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py
今天的内容是基于 LSTM 建立一个语言模型
人每次思考时不会从头开始,而是保留之前思考的一些结果,为现在的决策提供支持。RNN 的最大特点是可以利用之前的信息,即模拟一定的记忆,具体可以看我之前写过的这篇文章: 详解循环神经网络(Recurrent Neural Network) http://www.jianshu.com/p/39a99c88a565
RNN 虽然可以处理整个时间序列信息,但是它记忆最深的还是最后输入的一些信号,而之前的信号的强度就会越来越低,起到的作用会比较小。 而 LSTM 可以改善长距离依赖的问题,不需要特别复杂的调试超参数就可以记住长期的信息。关于 LSTM 可以看这一篇文章: 详解 LSTM http://www.jianshu.com/p/dcec3f07d3b5
今天要实现一个语言模型,它是 NLP 中比较重要的一部分,给上文的语境后,可以预测下一个单词出现的概率。
首先下载 ptb 数据集,有一万个不同的单词,有句尾的标记,并且将罕见的词汇统一处理成特殊字符;
$ wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz $ tar xvf simple-examples.tgz
PTBInput,
定义语言模型处理输入数据的一些参数,包括 LSTM 的展开步数 num_steps
,用 reader.ptb_producer
读取数据和标签:
PTBModel,
def __init__(self, is_training, config, input_)
包括三个参数,训练标记,配置参数以及输入数据的实例;
把这几个变量读取到本地,hidden_size
是隐藏层的节点数,vocab_size
是词汇表的大小;
def lstm_cell()
,设定基本的 LSTM 单元,用的是 tf.contrib.rnn.BasicLSTMCell
;
如果 if is_training and config.keep_prob < 1
这个条件的话,在 LSTM 单元后面可以加一个 dropout 层;
再用 tf.contrib.rnn.MultiRNNCell
把多层的 LSTM 堆加到一起;
用 cell.zero_state
将 LSTM 的初始状态设置为0;
接下来是 embedding 矩阵,行数是词汇表的大小,列数是每个单词的向量表达的维度,在训练过程中,它可以被优化和更新;
接下来我们要定义输出,限制一下反向传播时可以展开的步数,将 inputs 和 state 传到 LSTM,然后把输出结果添加到 outputs 的列表里;
然后将输出的内容串到一起,接下来 softmax 层,接着要定义损失函数 loss,它的定义形式是这样的:
然后我们要加和整个 batch 的误差,再平均到每个样本的误差,并且保留最终的状态,如果不是训练状态就直接返回;
接下来是定义学习速率,根据前面的 cost 计算一下梯度,并将梯度的最大范数设置好,相当于正则化的作用,可以防止梯度爆炸;
这个学习速率还可以更新,将其传入给 _new_lr
,再执行 _lr_update
完成修改:
接下来可以定义几种不同大小的模型的参数,其中有学习速率,还有梯度的最大范数,还是 LSTM 的层数,反向传播的步数,隐含层节点数,dropout 保留节点的比例,学习速率的衰减速度:
run_epoch
,是定义训练一个 epoch 数据的函数,首先初始化 costs 还有 iters,state;
将 LSTM 的所有 state 加入到 feed_dict
中,然后会生成结果的字典表 fetches,其中会有 cost 和 final_state
;
每完成 10% 的 epoch 就显示一次结果,包括 epoch 的进度,perplexity(是cost 的自然常数指数,这个指标越低,表示预测越好),还有训练速度(单词数每秒):
在 main() 中:
用 reader.ptb_raw_data
读取解压后的数据;
得到 train_data, valid_data, test_data
数据集;
用 PTBInput 和 PTBModel 分别定义用来训练的模型 m,验证的模型 mvalid,测试的模型 mtest;
m.assign_lr
对 m 应用累计的 learning rate;
每个循环内执行一个 epoch 的训练和验证,输出 Learning rate,Train Perplexity, Valid Perplexity。
- 【干货】TensorFlow实战——图像分类神经网络模型
- HTML5手机APP开发入(5)
- 这种自带黑科技的R包,请给我来一打
- 4927 线段树练习5
- codevs4919 线段树练习4
- 利用OpenCV和深度学习实现人脸检测
- 洛谷P2676 超级书架
- 洛谷P1720 月落乌啼算钱
- 2017.10.1解题报告
- 这个包绝对值得你用心体验一次!
- Python之函数的进阶(带参数的装饰器)
- 2017.10.2解题报告
- MVC 5 Scaffolder + EntityFramework+UnitOfWork Pattern 代码生成工具集成Visual Studio 2013
- 左手用R右手Python系列——百度地图API调用与地址解析/逆解析
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Go语言(十 八)context&日志项目
- 使用梯度上升欺骗神经网络,让网络进行错误的分类
- Go语言(十七) 配置文件库项目
- Python 相对路径问题:“No such file or directory“
- 基于etcd服务发现的overlay跨多宿主机容器网络
- Go语言(十六) 日志项目升级
- PyQt5 技术篇-设置窗口相对桌面位置,按屏幕比例
- Go语言(十五) 反射
- SpringBoot应用跨域访问解决方案
- Spring Boot 2.2都有哪些新变化
- Go语言(十四)日志项目
- 如何在Spring Boot中使用Cookies
- 在SpringBoot中使用flyway管理数据库版本状态
- 使用Spring Data JPA进行数据分页与排序
- 搭建一个高可用负载均衡的集群架构(第二部分)