利用RNN和LSTM生成小说题记
时间:2022-05-12
本文章向大家介绍利用RNN和LSTM生成小说题记,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
一、选取素材
- 本文选取的小说素材来自17k小说网的一篇小说《两只橙与遠太郎》,手工复制小说中的题记。
- 小说网址:http://www.17k.com/list/2793873.html
- 训练语料如下:
小说题记
- 语料格式:每行一条题记,以"题记:"开头,程序只保留冒号之后的正文(过短的行会被丢弃),例如:
题记:此情可待成追忆,只是当时已惘然。
二、开发环境
- Python 3 + TensorFlow 1.x(代码依赖 tf.contrib 模块,TensorFlow 2.x 中已移除该模块)
三、实战代码
#!/bash/bin
# -*-coding:utf-8-*-
import sys
import os
import numpy as np
import collections
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import tensorflow.contrib.legacy_seq2seq as seq2seq
BEGIN_CHAR = '^'  # prepended to every training line; also the seed for free-form sampling
END_CHAR = '$'  # appended to every training line; free-form sampling stops when it is drawn
UNKNOWN_CHAR = '*'  # replaces characters outside the vocabulary; also used as batch padding
MAX_LENGTH = 100  # lines longer than this are truncated (preferably at a '。')
MIN_LENGTH = 10  # lines of this length or shorter are discarded
max_words = 3000  # vocabulary cap: only the most frequent characters are kept
epochs = 50  # number of passes over the training batches
# Training corpus: one "<label>:<text>" line per entry
poetry_file = 'story.txt'
# Directory where model checkpoints are written and restored from
save_dir = 'model'
class Data:
    """Load the epigraph corpus, build the character vocabulary and batches.

    Each line of ``poetry_file`` is expected to look like ``<label>:<text>``;
    only the text after the first ':' is kept, and each kept line is wrapped
    as ``BEGIN_CHAR + text + END_CHAR``.

    Attributes set by the constructor:
        batch_size: number of lines per training batch (64).
        poetrys: wrapped corpus lines, sorted by length.
        words: vocabulary tuple, most frequent first, ending with UNKNOWN_CHAR.
        words_size: ``len(words)``.
        char2id / id2char: lookup callables between characters and ids.
        poetrys_vector: every line encoded as a list of ids.
        n_size: number of full batches.
        x_batches / y_batches: integer arrays of shape
            (batch_size, longest line in that batch); y is x shifted left by one.
    """

    def __init__(self):
        self.batch_size = 64
        self.poetry_file = poetry_file
        self.load()
        self.create_batches()

    def load(self):
        """Read the corpus file and build the vocabulary and id vectors.

        Raises:
            ValueError: if the file contains no usable lines.
        """

        def handle(line):
            # Truncate over-long lines after the last '。' within the first
            # MAX_LENGTH characters, or hard-cut to exactly MAX_LENGTH
            # characters when there is none (the old fallback kept
            # MAX_LENGTH + 1), then add the start/end markers.
            if len(line) > MAX_LENGTH:
                index_end = line.rfind('。', 0, MAX_LENGTH)
                index_end = index_end if index_end > 0 else MAX_LENGTH - 1
                line = line[:index_end + 1]
            return BEGIN_CHAR + line + END_CHAR

        texts = []
        # `with` guarantees the file handle is closed.
        with open(self.poetry_file, encoding='utf-8') as f:
            for raw in f:
                _, sep, text = raw.strip().replace(' ', '').partition(':')
                # Blank lines or lines without a ':' used to raise IndexError;
                # skip them. Text after a second ':' is now kept, not dropped.
                if sep:
                    texts.append(text)
        self.poetrys = [handle(text) for text in texts if len(text) > MIN_LENGTH]
        if not self.poetrys:
            raise ValueError(
                'no usable lines in {!r}: expected "<label>:<text>" lines with '
                'more than {} characters of text'.format(self.poetry_file, MIN_LENGTH))

        # Character frequencies over the whole corpus (markers included).
        counter = collections.Counter()
        for poetry in self.poetrys:
            counter.update(poetry)
        # Keep the max_words most frequent characters (ties in first-seen
        # order); everything else maps to UNKNOWN_CHAR.
        words = tuple(word for word, _ in counter.most_common(max_words))
        self.words = words + (UNKNOWN_CHAR,)
        self.words_size = len(self.words)
        # Character <-> id lookup tables.
        self.char2id_dict = {w: i for i, w in enumerate(self.words)}
        self.id2char_dict = {i: w for i, w in enumerate(self.words)}
        self.unknow_char = self.char2id_dict.get(UNKNOWN_CHAR)
        self.char2id = lambda char: self.char2id_dict.get(char, self.unknow_char)
        self.id2char = lambda num: self.id2char_dict.get(num)
        # Sorting by length groups lines of similar length into the same
        # batch, which keeps padding in create_batches() small.
        self.poetrys = sorted(self.poetrys, key=len)
        self.poetrys_vector = [list(map(self.char2id, poetry)) for poetry in self.poetrys]

    def create_batches(self):
        """Split the id vectors into padded (x, y) batches of batch_size lines.

        Lines beyond the last full batch are dropped. Shorter lines are
        right-padded in place with the UNKNOWN_CHAR id up to the longest line
        of their batch. y equals x shifted left by one position; its last
        column repeats x's last column.

        Raises:
            ValueError: if there are fewer lines than one batch, which would
                otherwise leave training with nothing to do.
        """
        self.n_size = len(self.poetrys_vector) // self.batch_size
        if self.n_size == 0:
            raise ValueError(
                'only {} usable lines; at least batch_size={} are needed to '
                'form one batch'.format(len(self.poetrys_vector), self.batch_size))
        self.poetrys_vector = self.poetrys_vector[:self.n_size * self.batch_size]
        self.x_batches = []
        self.y_batches = []
        for i in range(self.n_size):
            batches = self.poetrys_vector[i * self.batch_size: (i + 1) * self.batch_size]
            length = max(map(len, batches))
            for row in batches:
                row.extend([self.unknow_char] * (length - len(row)))
            xdata = np.array(batches)
            ydata = np.copy(xdata)
            ydata[:, :-1] = xdata[:, 1:]
            self.x_batches.append(xdata)
            self.y_batches.append(ydata)
class Model:
    """Character-level language model: embedding -> stacked RNN -> softmax.

    Args:
        data: a Data instance; its ``batch_size`` and ``words_size`` size the graph.
        model: cell type, one of 'rnn', 'gru' or 'lstm'.
        infer: when True the graph is built with batch size 1 (for sampling).

    Raises:
        ValueError: if ``model`` is not a supported cell type.
    """

    def __init__(self, data, model='lstm', infer=False):
        self.rnn_size = 128
        self.n_layers = 2
        self.batch_size = 1 if infer else data.batch_size

        if model not in ('rnn', 'gru', 'lstm'):
            # Previously an unknown name surfaced as a NameError on cell_rnn.
            raise ValueError('unknown model type: {!r}'.format(model))

        def make_cell():
            # Only BasicLSTMCell accepts ``state_is_tuple``; passing it to
            # BasicRNNCell / GRUCell raised TypeError. False keeps the LSTM
            # state one concatenated tensor, which sample() feeds back
            # through ``initial_state``.
            if model == 'rnn':
                return rnn.BasicRNNCell(self.rnn_size)
            if model == 'gru':
                return rnn.GRUCell(self.rnn_size)
            return rnn.BasicLSTMCell(self.rnn_size, state_is_tuple=False)

        # One distinct cell object per layer: ``[cell] * n_layers`` reused the
        # same cell for every layer, which TensorFlow 1.1+ rejects with a
        # cell/variable-reuse ValueError.
        self.cell = rnn.MultiRNNCell(
            [make_cell() for _ in range(self.n_layers)], state_is_tuple=False)

        self.x_tf = tf.placeholder(tf.int32, [self.batch_size, None])
        self.y_tf = tf.placeholder(tf.int32, [self.batch_size, None])
        self.initial_state = self.cell.zero_state(self.batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [self.rnn_size, data.words_size])
            softmax_b = tf.get_variable("softmax_b", [data.words_size])
            with tf.device("/cpu:0"):
                embedding = tf.get_variable(
                    "embedding", [data.words_size, self.rnn_size])
                inputs = tf.nn.embedding_lookup(embedding, self.x_tf)
        outputs, final_state = tf.nn.dynamic_rnn(
            self.cell, inputs, initial_state=self.initial_state, scope='rnnlm')

        # Flatten to (batch * time, rnn_size) so one matmul projects every step.
        self.output = tf.reshape(outputs, [-1, self.rnn_size])
        self.logits = tf.matmul(self.output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        self.final_state = final_state
        pred = tf.reshape(self.y_tf, [-1])
        # Per-position cross-entropy with every position weighted equally
        # (padding positions included).
        loss = seq2seq.sequence_loss_by_example(
            [self.logits], [pred], [tf.ones_like(pred, dtype=tf.float32)])
        self.cost = tf.reduce_mean(loss)
        # Not trained; train() assigns the decayed value each epoch.
        self.learning_rate = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        # Clip gradients to a global norm of 5 to keep updates bounded.
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 5)
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))
def train(data, model):
    """Train ``model`` on the batches of ``data`` and write checkpoints.

    Runs ``epochs`` epochs with learning rate ``0.002 * 0.97 ** epoch``.
    A checkpoint ``save_dir/model.ckpt-<step>`` is written every 1000
    batches and after the final batch. Progress is shown on one line
    that is rewritten in place.

    Returns:
        None.
    """
    # Saver.save fails when the target directory does not exist.
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
    total_batches = epochs * data.n_size
    # Build the learning-rate update op once, instead of adding a new
    # tf.assign op to the graph every epoch.
    lr_value = tf.placeholder(tf.float32, shape=[])
    lr_update = tf.assign(model.learning_rate, lr_value)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        n = 0
        for epoch in range(epochs):
            sess.run(lr_update, feed_dict={lr_value: 0.002 * (0.97 ** epoch)})
            for batche in range(data.n_size):
                n += 1
                feed_dict = {model.x_tf: data.x_batches[batche],
                             model.y_tf: data.y_batches[batche]}
                train_loss, _ = sess.run([model.cost, model.train_op], feed_dict=feed_dict)
                step = epoch * data.n_size + batche
                # '\r' returns to the start of the line so progress overwrites itself.
                sys.stdout.write('\r')
                info = "{}/{} (epoch {}) | train_loss {:.3f}".format(
                    step, total_batches, epoch, train_loss)
                sys.stdout.write(info)
                sys.stdout.flush()
                # save
                is_last_batch = epoch == epochs - 1 and batche == data.n_size - 1
                if step % 1000 == 0 or is_last_batch:
                    saver.save(sess, checkpoint_path, global_step=n)
                    sys.stdout.write('\n')
                    print("model saved to {}".format(checkpoint_path))
            sys.stdout.write('\n')
def sample(data, model, head=u''):
    """Generate one epigraph using the latest checkpoint in ``save_dir``.

    Args:
        data: the Data instance whose vocabulary the model was trained on.
        model: a Model, normally built with ``infer=True`` (batch size 1).
        head: optional acrostic prefix. Each character starts a clause that
            is extended until the model emits ',' or '。' (or END_CHAR, or
            MAX_LENGTH characters). When empty, free-form text is generated
            until END_CHAR or MAX_LENGTH characters.

    Returns:
        The generated text, or a message string when a character of ``head``
        is not in the vocabulary.

    Raises:
        ValueError: if no checkpoint exists in ``save_dir``.
    """

    def to_word(weights):
        # Draw a character with probability proportional to ``weights``.
        cumulative = np.cumsum(weights)
        # Scale by the cumulative total itself (np.sum may round differently)
        # and clamp, so the index can never run past the vocabulary and map
        # to None. A scalar draw avoids int() on a 1-element array.
        idx = int(np.searchsorted(cumulative, np.random.rand() * cumulative[-1]))
        return data.id2char(min(idx, len(cumulative) - 1))

    def step(sess, ids, state):
        # Feed one row of ids starting from ``state``; return the sampled
        # next character and the resulting state.
        x = np.array([ids], dtype=np.int32)
        probs, new_state = sess.run([model.probs, model.final_state],
                                    {model.x_tf: x, model.initial_state: state})
        return to_word(probs[-1]), new_state

    for word in head:
        if word not in data.char2id_dict:
            return u'{} 不在字典中'.format(word)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        model_file = tf.train.latest_checkpoint(save_dir)
        if model_file is None:
            raise ValueError(
                'no checkpoint found in {!r}; train the model first'.format(save_dir))
        saver.restore(sess, model_file)
        zero_state = sess.run(model.cell.zero_state(1, tf.float32))

        if head:
            print('生成题记 ---> ', head)
            poem = BEGIN_CHAR
            for head_word in head:
                poem += head_word
                # Re-read the whole text so far from a fresh state.
                word, state = step(sess, [data.char2id(c) for c in poem], zero_state)
                # Extend the clause until punctuation. Also stop on END_CHAR
                # or after MAX_LENGTH characters so a model that never emits
                # ',' / '。' cannot loop forever.
                for _ in range(MAX_LENGTH):
                    if word in (u',', u'。', END_CHAR):
                        break
                    poem += word
                    word, state = step(sess, [data.char2id(word)], state)
                if word != END_CHAR:
                    poem += word
            return poem[1:]

        poem = ''
        word, state = step(sess, [data.char2id(BEGIN_CHAR)], zero_state)
        # Cap the length so sampling terminates even if END_CHAR never comes.
        while word != END_CHAR and len(poem) < MAX_LENGTH:
            poem += word
            word, state = step(sess, [data.char2id(word)], state)
        return poem
if __name__ == '__main__':
    # Usage:
    #   python <this file>                 train the model (default)
    #   python <this file> sample [HEAD]   generate an epigraph from the saved
    #                                      checkpoint, e.g. HEAD=我为秋香
    data = Data()
    if len(sys.argv) > 1 and sys.argv[1] == 'sample':
        # Sampling feeds one sequence at a time, so build the graph with batch size 1.
        model = Model(data=data, infer=True)
        print(sample(data, model, head=sys.argv[2] if len(sys.argv) > 2 else u''))
    else:
        model = Model(data=data, infer=False)
        # train() returns None, so its result is not printed.
        train(data, model)
四、输出
生成题记 ---> 我为秋香
我罢性不行,为德劝仙兴。秋风暝冰始,香巢深器酒。
相关文章
- Java 解析Excel文件为JSON
- SQL语句大小写是否区分的问题,批量修改整个数据库所有表所有字段大小写
- CentOS 6.5 安装nginx 1.6.3
- C#创建数字证书并导出为pfx,并使用pfx进行非对称加解密
- MyBatis两张表字段名相同产生的问题
- mongo 3.0 备份和还原数据库 ,及too many positional arguments错误
- AngularJs HTTP响应拦截器实现登陆、权限校验
- C# 读写App.config配置文件的方法
- Golang语言社区--Go语言基础第四节类型
- Golang语言社区--go语言编写Web程序
- Golang语言社区--Go语言基础第五节流程控制
- (14)不同基因坐标转换-生信菜鸟团博客2周年精选文章集
- (15)基因组各种版本对应关系-生信菜鸟团博客2周年精选文章集
- go 并发处理脚本
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Yii框架响应组件用法实例分析
- Android开发学习实现简单计算器
- Android Studio finish()方法的使用与解决app点击“返回”(直接退出)
- Android 8.1隐藏状态栏图标的实例代码
- Android制作登录页面并且记住账号密码功能的实现代码
- Yii框架分页技术实例分析
- PHP命名空间与自动加载机制的基础介绍
- Flutter下Android Studio配置gradle的方法
- Flutter 实现整个App变为灰色的方法示例
- Android studio开发小型对话机器人app(实例代码)
- php中的钩子理解及应用实例分析
- AndroidX下使用Activity和Fragment的变化详解
- PHP Primary script unknown 解决方法总结
- PHP如何将图片文件上传到另外一台服务器上
- android实现滑动解锁