Tensorflow2.0实现简单的RNN文本分析
前面我们介绍的全连接神经网络以及卷积神经网络都只能单独处理一个个输入,并且前一个输入和后一个输入往往是没有直接联系。但是,在某些情况下我们需要很好地处理序列信息,即前一个输入与后一个输入是有关系的。比如我们理解一句话的时候,往往需要联系前后的句子才能得到这句话表达的准确含义。序列问题有很多,例如语音对话、文本理解以及视频/音频分析等。今天老shi将给大家介绍深度学习中另外一种非常重要的神经网络类型——循环神经网络RNN,它最擅长处理序列问题!
举个栗子,比如,老师说小明总是上课迟到,今天罚____打扫卫生。很明显,这里空缺的部分大概率是说小明,而不是我。
那么,循环神经网络到底是啥?循环神经网络(Recurrent Neural Network)种类繁多,我们先从最简单的循环神经网络开始吧。
基本的循环神经网络
下图是一个简单的循环神经网络,它由输入层、一个隐藏层和一个输出层组成:
这个图看起来有点奇怪,跟我们之前介绍的神经网络都不太一样。这是因为循环神经网络实在是不好画,这是简化+抽象后的画法。如果把上面有W的那个带箭头的圈去掉,它就变成了最普通的全连接神经网络。x是一个向量,它表示输入层的值(这里神经元节点没有画出来);s是一个向量,它表示隐藏层的值(这里隐藏层面只画了一个节点,你也可以想象这一层其实是有多个节点,节点数与向量s的维度相同);U是输入层到隐藏层的权重矩阵(类似于全连接神经网络中每层的权重);o也是一个向量,它表示输出层的值;V是隐藏层到输出层的权重矩阵。现在我们来看看W到底是什么?因为循环神经网络的隐藏层的值s不仅仅取决于当前这一次的输入x,还取决于上一次隐藏层的值s。所以,权重矩阵W就是隐藏层上一次的值作为这一次的输入的权重。
如果我们把上面的图展开,循环神经网络大概就是下面这个样子:
现在看上去就比较清楚了,这个网络在t时刻接收到输入Xt之后,隐藏层的值是St,输出值是Ot。关键一点是,St的值不仅仅取决于Xt,还取决于St-1。我们可以用下面的公式来表示循环神经网络的计算方法:
式1是输出层的计算公式,输出层是一个全连接层,也就是它的每个节点都和隐藏层的每个节点相连。V是输出层的权重矩阵,g是激活函数。式2是隐藏层的计算公式,它是循环层。U是输入x的权重矩阵,W是上一次的值作为这一次的输入的权重矩阵,f是激活函数。
从上面的公式我们可以看出,循环层和全连接层的区别就是循环层多了一个权重矩阵 W。如果反复把式2代入到式1,我们将得到:
从上面可以看出,循环神经网络的输出值,是受前面很多次输入值影响的,这就是为什么循环神经网络可以往前看任意多个输入值的原因。
当然,前面介绍的只是最基本的循环神经网络结构,除此之外,其实循环神经网络还有例如双向循环神经网络、深度循环神经网络以及很多其他的变种,这里老shi不打算一次给大家介绍完(介绍了你们也接受不了,哈哈)。最后是一个非常简单的文本分析RNN代码实践案例,有兴趣的同学可以跟着现实一下。下节课老shi准备给大家介绍非常常用的RNN变种LSTM和GRU,敬请期待!!
from tensorflow import kerasfrom tensorflow.keras import layers
num_words = 30000maxlen = 200
#导入数据(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=num_words)
#x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen, padding='post')x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen, padding='post')print(x_train.shape, ' ', y_train.shape)print(x_test.shape, ' ', y_test.shape)
def RNN_model(): model = keras.Sequential([ layers.Embedding(input_dim=30000, output_dim=32, input_length=maxlen), layers.SimpleRNN(32, return_sequences=True), layers.SimpleRNN(1, activation='sigmoid', return_sequences=False) ]) model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.BinaryCrossentropy(), metrics=['accuracy'])return model
model = RNN_model()model.summary()
history = model.fit(x_train, y_train, batch_size=64, epochs=5,validation_split=0.1)
- 一大拨漏洞来袭,eBAY的黑色星期五
- 基于 Docker 持续交付平台建设的实践
- Struts原理与实践
- 玩转WiFi Pineapple之看我如何优雅的盗取CMCC账号
- iOS多边形马赛克的实现(下)
- 见招拆招:绕过WAF继续SQL注入常用方法
- 从零开始在Python中实现决策树算法
- 走进科学:揭秘如何入侵电视机
- iOS多边形马赛克的实现(上)
- Android终端上视频转GIF的实现及GIF质量讨论
- Android手机上用户操作模拟方法的研究与实现
- Firefox内存释放重用漏洞高级利用(Pwn2Own2014、CVE-2014-1512)
- android 线程那点事
- android 向webview传值
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 解决Android Studio Design界面不显示layout控件的问题
- Android读写文件工具类详解
- Kotlin实现在类里面创建main函数
- AndroidStudio 设置格式化断行宽度教程
- 从 SpringBoot 到 SpringMVC
- AndroidManifest.xml中含盖的安全问题详解
- Android Studio实现格式化XML代码顺序
- android自动生成dimens适配文件的图文教程详解(无需Java工具类)
- Android Studio自动提取控件Style样式教程
- 基于Android studio3.6的JNI教程之ncnn人脸检测mtcnn功能
- Kotlin 使用Lambda来设置回调的操作
- Kotlin之自定义 Live Templates详解(模板代码)
- Android Studio设置颜色拾色器工具Color Picker教程
- Kotlin中常见的符号详解
- Kotlin中实体类的创建方式