使用CNN+ Auto-Encoder 实现无监督Sentence Embedding (代码基于Tensorflow)
1前言
这篇文章会利用到上一篇: 基于Spark /Tensorflow使用CNN处理NLP的尝试的数据预处理部分,也就是如何将任意一段长度的话表征为一个2维数组。
本文完整的代码在这: autoencoder-sentence-similarity.py(https://gist.github.com/allwefantasy/51275cb5c649e4a69b33131e967e2af9#file-autoencoder-sentence-similarity-py)
基本思路是,通过编码解码网络(有点类似微软之前提出的对偶学习),先对句子进行编码,然后进行解码,解码后的语句要和原来的句子尽可能的接近。训练完成后,我们就可以将任意一个句子进行编码为一个向量,这算是Sentence Embedding的一种新的实现。最大的好处是,整个过程无需标注语料,属于无监督类学习。这次我还在编码前引入卷积网络,不过效果有待验证。
2准备工作
我们假设大家已经准备了两个数据集,具体可以参考上一篇文章的Spark预处理部分:
- 已经分好词的语句
- 词到vector的字典
然后我们使用如下函数来进行处理:
def next_batch(batch_num, batch_size, word_vec_dict):
with open(WORD_FILE) as wf:
line_num = 0
start_line_num = batch_num * batch_size
batch_counter = 0
result = []
for words in wf:
result1 = []
line_num += 1
if line_num > start_line_num:
batch_counter += 1
for word in words.split(" "):
if word in word_vec_dict:
result1.append(word_vec_dict[word])
if len(result1) < SEQUENCE_LENGTH:
for i in range(SEQUENCE_LENGTH - len(result1)): result1.append(np.zeros(shape=(VOCAB_SIZE, 1)).tolist())
result.append([str(line_num), result1[0:SEQUENCE_LENGTH]])
if batch_counter == batch_size:
return result
字典的格式是: word + 空格 + 100个逗号分隔的数字
文本内容格式是: 通过空格分隔的已经分好词的句子
因为这次测试数据集有点大,所以没办法一次性载入到内存,只能分批了。缺点是,每一次都需要重新打开文件,为了减少打开文件次数,程序后半部分做了一些优化处理,基本方式为,一次性从文件里取batch_size 条数据,然后让Tensorflow 再分 batch_size / mini_train_batch_size 次进行迭代。每次迭代给的样本量还是比较影响效果的,4000和200比,有20%左右的差异。
3构建模型
我尝试了如下两个拓扑,第一个是带卷积的:
Input(段落) -> 卷积 -> 池化 -> 卷积 -> 池化 -> encoder -> encoder -> decoder -> decoder -> lost function (consine夹角)
第二个则是不带卷积:
Input(段落) -> encoder -> encoder -> decoder -> decoder -> lost function (consine夹角)
基本上是两层卷积,两层编解码器。
训练完成后,就获得编码器的所有参数,利用encoder_op 对所有的语句进行编码,从而实现所有语句得到一个唯一的向量(128维)表示。
大概准备了 60多万条语句进行训练,经历了50*60轮迭代后,不带卷积最后loss 大概是从1.1 下降到0.94的样子。如果进行更多迭代,提供更多训练数据,应该可以进一步降低。
带卷积的收敛较快,loss 从0.5 经过3000轮迭代,可以下降到0.1 左右。
因为语料比较隐私,无法提供,但是可以描述下大致的结论,随机找一段话,然后寻找相似的,目前来看,不带卷积的效果非常好,带卷积的因为卷积后信息损失太大,在encoder-decoder阶段感觉无法训练了,最后趋同,因此需要对卷积进行较大调整。关于NLP的卷积,其实我们不一定要保证卷积核的宽度或者高度保持和word embedding的size一样,因为对两个word截取其一般,我认为应该还是有一定的相似度的。
在训练过程中,损失函数我尝试了:
xent =tf.reduce_mean(tf.pow([y_true, y_pred],2), name="xent")
以及
consine = tf.div(tf.reduce_sum(tf.multiply(y_pred, y_true)),
tf.multiply(tf.sqrt(tf.reduce_sum(tf.multiply(y_pred, y_pred))), tf.sqrt(tf.reduce_sum(tf.multiply(y_true, y_true))))) xent = tf.reduce_sum(tf.subtract(tf.constant(1.0), consine))
因为采用欧式距离,我们难以确定相似度的阈值,而cosine是一个比较容易衡量的值。所以这里选择了consine作为损失函数。我没有找到Tensorflow的实现,所以完全根据consine公式自己实现了一个。
对于Optimizer,我尝试了
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(xent) train_step = tf.train.RMSPropOptimizer(learning_rate).minimize(xent)
目前来看,RMSPropOptimizer效果好很多。
4总结
现阶段大量优秀的人才都投入到了深度学习上,所以深度学习也取得了越来越多的进展,用法也越来越多,尤其是对抗学习,加强学习,对偶学习的发展让深度学习可以做的事情越来越多。深度学习在NLP文本分类,特征抽取方面,我觉得还是有潜力可挖的。不过,我觉得深度学习其实是把机器学习的门槛提的更高了,虽然越来越多的工具(比如Tensorflow)和已知的各种实践似乎在降低某些门槛。
- 万达网科裁员95% 王健林曾宣布要在2020年整体上市
- Linq学习笔记(三)
- Go语言cmd命令通过管道实现交互
- 三撩Python
- linq学习笔记(二)
- 盘点2017十大科学突破,让孩子与未来相遇
- ASP.NET 2.0 中的异步页[来自MSDN]
- 温习sql语句中JOIN的各种操作(SQL2005环境)
- 揭开ps的神秘面纱——初步认识photoshop
- 地理坐标系与投影坐标系的区别
- ExtJs学习笔记(6)_可分页的GridPanel
- PowerDesinger联系的定义及使用
- Gis链接
- TortoiseSVN文件夹及文件图标不显示解决方法 TortoiseSVN文件夹及文件图标不显示解决方法
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 打卡群刷题总结0922——丑数 II
- 打卡群刷题总结0923——完全平方数
- 打卡群刷题总结0924——最长上升子序列
- VS2017中使用QT Chart图表
- C++核心准则T.81:不要混用继承层级和数组
- C++核心准则T.83:不要将成员函数定义为模板虚函数
- C++核心准则T.84:使用非模板核心实现提供稳定的ABI接口
- C++核心准则T.120:只在确实有需要时使用模板元编程
- C++核心准则T.121:模板元编程主要用于模仿概念
- C++核心准则T.122:使用模板在编译时计算类型
- C++核心准则T.123:使用常量表达式函数在编译时求值
- Java基础 【类之间的关系】
- MySql 学习之路-基础
- (有趣的)项目实战:Java实现计算机自动关机
- 猜生日 Java小游戏