理工男的文科梦 —— LSTM深度学习写春联
笔者作为一名根正苗红的理工男,内心却常常有很多文艺青年才会有的想法,例如写首诗、做首词,甚至包括春节写副对联,空有一番愿望却胸无点墨。随着对机器学习和深度学习的了解,逐渐萌生了使用机器帮助笔者完成文艺青年的转型。
本文借助递归神经网络RDD的变种之一LSTM算法,对收集到的6900多条对联进行学习,训练好模型后可以由机器写出对联。
递归神经网络与LSTM
故事从人工神经网络开始,人工神经网络诞生已久。如下图所示,神经网络的基本结构由输入层、输出层和一个或多个隐含层组成。
全连接的神经网络下一层神经元的输入由上一层所有神经元的输出决定,因此带来了一个严重的问题即参数数量过大导致无法训练。因此,随时神经网络的发展,衍生了一系列的变化。比较流行的有应用于图像识别领域的卷积神经网络CNN、应用于自然语言处理的递归神经网络RNN。本文应用到的LSTM算法即为RNN的一种形态。RNN解决了这样的问题:即样本出现的时间顺序对于自然语言处理、语音识别、手写体识别等应用非常重要,神经元的输出可以在下一个时间戳直接作用到自身。因此RNN很适合处理时序对结果影响较深的领域。
关于RNN和LSTM原理的说明可以移步 http://www.jianshu.com/p/9dc9f41f0b29 ,本文不多加赘言。
由LSTM作诗引发
由于LSTM算法非常适用自然语言处理领域,因此网上出现了很多应用LSTM做文字领域的尝试,例如: LSTM写诗 中使用LSTM写诗,LSTM创作歌词中使用LSTM模仿歌手风格写歌词,以及使用LSTM算法给小孩起名(是多么不靠谱的粑粑麻麻)。
因此,笔者突发想法,如果给一个足够的春联训练样本,一样可以照猫画老虎,训练一个可以写对联的文艺“机器模型”。因此,问题就分解为:找样本、写算法、训练、应用模型。
春联样本搜集和规整
借助于强大的度娘,费劲九牛之力,从网上搜集了各式春联共6900对,其中上联下联之间是用","分割区分上下联,对联之间是用"。"区分一联的结束。样式如下:
这些样本将会在训练阶段进行类型转换并输入给LSTM模型中。如果您也想试下本文案例,请私信我这些样本(毕竟搜集训练样本是个苦差事(: )
LSTM算法
本文使用TensorFlow进行建模,TensorFlow就无需多言,是这个领域目前最活跃的框架。写对联的算法主要工作包括:根据样本数据产生LSTM输入数据和结果值;定义LSTM的模型以及损失函数;将训练数据喂给TensorFlow用来训练模型。接下来会逐步列举本例中使用的方法。
- 训练数据转换
由于样本数据是一条条汉字组成的对联,这样的数据是无法交给模型训练的,因此需要对样本数据进行转换。基本思想是:
- 将样本的所有对联加载录入,统计出所有出现的汉字,并将汉字进行编码,例如:一共有10000个汉字出现在样本中,那么对出现的汉字按 0 - 999 进行编码,每个汉字对应一个编码。
- 对原始样本进行编码转换,生成用数字编码表示的对联集。
- 每条对联作为一个输入序列,每批次训练batch_size条,生成输入数据xdata,输出y值为xdata+1。因为文本分析的特点是有时序性。
couplet_file ="couplet.txt"#对联couplets = []with open(couplet_file,'r') as f: for line in f: try:
content = line.replace(' ','') if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content: continue
if len(content) < 5*3 or len(content) > 79*3: continue
content = '[' + content + ']'
# print chardet.detect(content)
content = content.decode('utf-8')
couplets.append(content) except Exception as e: pass# 按字数排序couplets = sorted(couplets,key=lambda line: len(line))
print('对联总数: %d'%(len(couplets)))# 统计每个字出现次数all_words = []for couplet in couplets:
all_words += [word for word in couplet]
counter = collections.Counter(all_words)
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
words, _ = zip(*count_pairs)
words = words[:len(words)] + (' ',)# 每个字映射为一个数字IDword_num_map = dict(zip(words, range(len(words))))
to_num = lambda word: word_num_map.get(word, len(words))
couplets_vector = [ list(map(to_num, couplet)) for couplet in couplets]# 每次取64首对联进行训练, 此参数可以调整batch_size = 64n_chunk = len(couplets_vector) // batch_size
x_batches = []
y_batches = []for i in range(n_chunk):
start_index = i * batch_size#起始位置
end_index = start_index + batch_size#结束位置
batches = couplets_vector[start_index:end_index]
length = max(map(len,batches))#每个batches中句子的最大长度
xdata = np.full((batch_size,length), word_num_map[' '], np.int32) for row in range(batch_size):
xdata[row,:len(batches[row])] = batches[row]
ydata = np.copy(xdata)
ydata[:,:-1] = xdata[:,1:]
x_batches.append(xdata)
y_batches.append(ydata
定义LSTM模型
- 使用TF api tf.nn.rnn_cell.BasicLSTMCell定义cell为一个128维的ht的cell。并使用MultiRNNCell 定义为两层的LSTM。
- 对训练样本输入进行embedding化。
- 使用tf.nn.dynamic_rnn计算输出值。(也可以通过循环step的方法,依次计算)
- 加入softmax层。
def neural_network(rnn_size=128, num_layers=2):
cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
initial_state = cell.zero_state(batch_size, tf.float32) with tf.variable_scope('rnnlm'):
softmax_w = tf.get_variable("softmax_w", [rnn_size, len(words)+1])
softmax_b = tf.get_variable("softmax_b", [len(words)+1]) with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [len(words)+1, rnn_size])
inputs = tf.nn.embedding_lookup(embedding, input_data)
outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state, scope='rnnlm')
output = tf.reshape(outputs,[-1, rnn_size])
logits = tf.matmul(output, softmax_w) + softmax_b
probs = tf.nn.softmax(logits) return logits, last_state, probs, cell, initial_state
训练阶段
- 使用TF sequence_loss_by_example计算所有examples(假设一句话有n个单词,一个单词及单词所对应的label就是一个example,所有example就是一句话中所有单词)的加权交叉熵损失。
- tf.gradients 计算梯度,并使用clip_by_global_norm控制梯度爆炸的问题。梯度爆炸和梯度弥散的原因一样,都是因为链式法则求导的关系,导致梯度的指数级衰减。为了避免梯度爆炸,需要对梯度进行修剪。(来自网上的解释,不明觉厉(: )
- 定义步长,步长过大,会很可能越过最优值,步长过小则使优化的效率过低,长时间无法收敛。因此learning rate是一个需要适当调整的参数。一个小技巧是,随时训练的进行,即沿着梯度方向收敛的过程中,适当减小步长,不至于错过最优解。在代码中 0.01 * (0.97 ** epoch),learing rate基数值为0.01, 系数为0.97的epoch方,可以看出epoch越大,learing rate越小。
- 分批次将样本数据x_batches和y_batches喂给TF进行训练。
def train_neural_network():
logits, last_state, _, _, _ = neural_network()
targets = tf.reshape(output_targets, [-1])
loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [targets], [tf.ones_like(targets, dtype=tf.float32)], len(words))
cost = tf.reduce_mean(loss)
learning_rate = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.apply_gradients(zip(grads, tvars)) with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver(tf.all_variables()) for epoch in range(100):
sess.run(tf.assign(learning_rate, 0.01 * (0.97 ** epoch)))
n = 0
for batche in range(n_chunk):
train_loss, _ , _ = sess.run([cost, last_state, train_op], feed_dict={input_data: x_batches[n], output_targets: y_batches[n]})
n += 1
print(epoch, batche, train_loss) if epoch % 7 == 0:
saver.save(sess, './couplet.module', global_step=epoch)
训练结束 , 诗性大发
经过漫长的训练(取决于样本数和迭代次数), loss控制在1.5左右。
可以看到,经过100次的迭代训练,每7次保存一次(saver.save(sess, './couplet.module', global_step=epoch)), 最后的模型保存在couplet.module-98里。
在eval阶段,使用saver.restore(sess, 'couplet.module-98') 将训练好的模型加载, 因为机器算出来的依旧是上文提到的数字编码,因此需要再将数字转为汉字。
好啦,来看看机器创作的对联吧, 是不是有点意思呢?
- Java8新特性——StreamAPI(二)
- 从WordPress SQLi谈PHP格式化字符串问题
- 自己实现PC端jQuery版轮播图
- Tips for thrift
- Java8新特性——StreamAPI(一)
- springcloud学习手册-Hystrix(仪表盘说明)
- RedRabbit——基于BrokerPattern服务器框架
- C++任务队列与多线程
- 游戏服务器ID生成器组件
- Java8新特性——Lambda表达式(一)
- C++ FFLIB之FFRPC:多线程&多进程的scalability探索
- ffrpc-c++进程间(服务器端、客户端)通信框架
- Docker入门实战(一)——Docker常用命令
- C++使用ffpython嵌入和扩展python
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Chrome 插件开发-主动获取所有页签的tabid,background.js直接向所有页签同步传递消息
- Python 技术篇-httpClient库连接服务器发送请求解析响应实例演示,No module named ‘httplib‘问题解决方法
- Windows 技术篇-windows家庭版win+R运行搜不到gpedit.msc组策略编辑器解决方法,windows安装组策略编辑器方法
- SpringBoot中Tomcat是如何启动的
- JavaScript 技术篇-chrome浏览器读取剪切板命令document.execCommand(‘paste‘)返回false原因及解决方法
- Linux 技术篇-文件大小查看方法实例演示,查看指定文件大小,查看列表下所有文件夹和文件的大小
- Python 技术篇-判断指定路径下的文件是否处于打开状态或占用状态实例演示
- Python 爬虫进阶篇-利用beautifulsoup库爬取网页文章内容实战演示
- Windows 技术篇-cmd强制关闭端口、解除端口占用方法,cmd查询端口相关的进程pid并杀死进程实例演示
- PL/SQL 数据库连接工具的下载、安装与使用实例演示
- JavaScript 技术篇-JSON字符串在线快速格式化查看实例演示,json.cn网址格式化json字符串
- Python 技术篇-含中文编码的代码运行方法,(unicode error) ‘utf-8‘ codec can‘t decode问题原因及解决方法
- Spring Boot Actuators
- 用图机器学习探索 A 股个股相关性变化
- Python 技术篇-连接oracle数据库并执行sql语句实例演示,python连接oracle数据库oci详细配置方法