TensorFlow in Practice Made Easy (LSTM)
Date: 2022-07-22
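Introduction to LSTM
LSTM formulas
LSTM is a gated recurrent neural network, so we approach it through its gating units. The main components are:
- Input gate: I_t
- Forget gate: F_t
- Output gate: O_t
- Candidate cell: ~C_t
- Cell: C_t
- Hidden state: H_t
Suppose the hidden state has length h, and X_t is a batch of n samples whose feature vectors have dimension x. In the computation, W denotes weights and b denotes biases.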
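Written out under the standard LSTM formulation, the quantities listed above can be computed as below. The weight names (W_{xi}, W_{hi}, and so on), \sigma for the sigmoid function, and \odot for element-wise multiplication are notation assumed here; X_t has shape n×x, and H_{t-1} and C_{t-1} have shape n×h.

```latex
\begin{aligned}
I_t &= \sigma(X_t W_{xi} + H_{t-1} W_{hi} + b_i) \\
F_t &= \sigma(X_t W_{xf} + H_{t-1} W_{hf} + b_f) \\
O_t &= \sigma(X_t W_{xo} + H_{t-1} W_{ho} + b_o) \\
\tilde{C}_t &= \tanh(X_t W_{xc} + H_{t-1} W_{hc} + b_c) \\
C_t &= F_t \odot C_{t-1} + I_t \odot \tilde{C}_t \\
H_t &= O_t \odot \tanh(C_t)
\end{aligned}
```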
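In the end there are only two results: the output and the state. The output is H_t, and the state is the pair (C_t, H_t); everything else is an intermediate computation. [2]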
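To make the distinction between output and state concrete, here is a minimal NumPy sketch of a single LSTM step. It assumes the standard LSTM equations; the function name lstm_step, the packed weight layout (gate order i, f, o, candidate), and the small sizes are illustrative choices, not TensorFlow's internals.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W_x, W_h, b):
    """One LSTM step (hypothetical helper).

    W_x: (x, 4h), W_h: (h, 4h), b: (4h,), packed as [input, forget, output, candidate].
    """
    h = h_prev.shape[1]
    z = x_t @ W_x + h_prev @ W_h + b
    i_t = sigmoid(z[:, 0 * h:1 * h])        # input gate
    f_t = sigmoid(z[:, 1 * h:2 * h])        # forget gate
    o_t = sigmoid(z[:, 2 * h:3 * h])        # output gate
    c_tilde = np.tanh(z[:, 3 * h:4 * h])    # candidate cell
    c_t = f_t * c_prev + i_t * c_tilde      # new cell
    h_t = o_t * np.tanh(c_t)                # new hidden state
    # The output is H_t; the state is the pair (C_t, H_t)
    return h_t, (c_t, h_t)

rng = np.random.default_rng(0)
n, x_dim, h_dim = 4, 28, 8                  # small illustrative sizes
x_t = rng.normal(size=(n, x_dim))
h0 = np.zeros((n, h_dim))
c0 = np.zeros((n, h_dim))
W_x = rng.normal(scale=0.1, size=(x_dim, 4 * h_dim))
W_h = rng.normal(scale=0.1, size=(h_dim, 4 * h_dim))
b = np.zeros(4 * h_dim)

output, (c1, h1) = lstm_step(x_t, h0, c0, W_x, W_h, b)
print(output.shape, c1.shape, h1.shape)
```

The returned output and the second element of the state are the same array, which is why the output of the last step can be read from either place.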
LSTM illustrated
- Forget gate
- Input gate
- Current state
- Output layer
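TensorFlow LSTM
TensorFlow provides two LSTM cells. tf.nn.rnn_cell.BasicLSTMCell() is a basic version that leaves out some of the more advanced LSTM extensions, while tf.nn.rnn_cell.LSTMCell() is the standard interface that includes them. Here we implement the basic version. [1] Note that the code in this article is written for TensorFlow 1.x: it depends on tensorflow.contrib, which was removed in TensorFlow 2.x.
Implementing LSTM in TensorFlow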
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib import rnn
Loading the dataset
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/", one_hot=True)
Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz
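Setting the parameters
# Training parameters
learning_rate = 0.001 # learning rate
training_steps = 10000 # total number of training steps
batch_size = 128 # batch size
display_step = 200 # print loss and accuracy every display_step steps
# Network parameters
num_input = 28 # MNIST images are 28*28; one row of 28 pixels is fed per timestep
timesteps = 28 # number of timesteps (one per image row)
num_hidden = 128 # number of hidden units
num_classes = 10 # number of MNIST classes (digits 0-9)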
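The choice num_input = 28 and timesteps = 28 means each image is read as a sequence of 28 rows with 28 pixels per row. The following NumPy sketch shows that reshaping on stand-in data; the array flat_batch and the batch size of 3 are made up for illustration and are not MNIST data.

```python
import numpy as np

timesteps, num_input = 28, 28
batch_size = 3  # small hypothetical batch for illustration

# Stand-in for MNIST images stored as flat vectors of 784 pixels each
flat_batch = np.arange(batch_size * 784, dtype=np.float32).reshape(batch_size, 784)

# Each image becomes a sequence of 28 rows with 28 pixels per row
sequences = flat_batch.reshape((batch_size, timesteps, num_input))

print(sequences.shape)  # (3, 28, 28)
# Row t of an image holds pixels t*28 .. t*28+27 of its flat vector
print(np.array_equal(sequences[0, 1], flat_batch[0, 28:56]))  # True
```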
Building the LSTM network
# Define the inputs
X = tf.placeholder("float", [None, timesteps, num_input])
Y = tf.placeholder("float", [None, num_classes])
# Define the weights and biases of the output layer
# Weight matrix shape: [num_hidden, num_classes] = [128, 10]
weights = {
    'out': tf.Variable(tf.random_normal([num_hidden, num_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([num_classes]))
}
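# Define the LSTM network
def LSTM(x, weights, biases):
    # Prepare data shape to match `rnn` function requirements
    # Shape of the input x: (batch_size, timesteps, num_input)
    # Required shape: a list of `timesteps` tensors, each (batch_size, num_input)
    # Split x along the time axis
    # tf.unstack(value, num=None, axis=0, name='unstack')
    # value: the tensor to split
    # axis: int, the dimension to split along
    # num: int, the length of dimension `axis`
    x = tf.unstack(x, timesteps, 1)
    # Define an LSTM cell, i.e. the block labelled A in the LSTM diagram above
    # num_hidden is the number of hidden units. forget_bias is a constant added
    # to the forget gate's pre-activation; 1.0 biases the gate toward keeping
    # the cell state early in training. It does not mean "never forget".
    lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)
    # Run the LSTM cell over the sequence
    # Returns outputs and states
    # outputs is a list of length `timesteps`; outputs[-1] is the last output
    # states is the final state
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
    # Linear output layer on the last output:
    # matrix multiplication plus bias
    return tf.matmul(outputs[-1], weights['out']) + biases['out']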
logits = LSTM(X, weights, biases)
prediction = tf.nn.softmax(logits)
# Define the loss function and the optimizer
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
    logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)
# Evaluate the model: fraction of samples whose predicted class matches the label
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Initialize the global variables
init = tf.global_variables_initializer()
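The two shape changes in the network definition, splitting the input along the time axis and projecting the last output to class scores, can be checked without TensorFlow. This is a NumPy sketch of the shapes only: the list comprehension is an analogue of tf.unstack(x, timesteps, 1), and last_output, W_out and b_out are placeholder arrays standing in for outputs[-1], weights['out'] and biases['out'].

```python
import numpy as np

batch_size, timesteps, num_input = 2, 28, 28
num_hidden, num_classes = 128, 10

x = np.zeros((batch_size, timesteps, num_input), dtype=np.float32)

# NumPy analogue of tf.unstack(x, timesteps, 1): one slice per timestep
steps = [x[:, t, :] for t in range(timesteps)]
print(len(steps), steps[0].shape)  # 28 (2, 28)

# Placeholder for the last LSTM output, shape (batch_size, num_hidden)
last_output = np.ones((batch_size, num_hidden), dtype=np.float32)
W_out = np.zeros((num_hidden, num_classes), dtype=np.float32)
b_out = np.zeros(num_classes, dtype=np.float32)

# Linear output layer: one score per class for every sample
logits = last_output @ W_out + b_out
print(logits.shape)  # (2, 10)
```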
Training and testing
# Start training
with tf.Session() as sess:
    # Run the initializer
    sess.run(init)
    for step in range(1, training_steps+1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 sequences of 28 elements
        batch_x = batch_x.reshape((batch_size, timesteps, num_input))
        # Run optimization op (backprop)
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
        if step % display_step == 0 or step == 1:
            # Calculate batch loss and accuracy
            loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,
                                                                 Y: batch_y})
            print("Step " + str(step) + ", Minibatch Loss= " +
                  "{:.4f}".format(loss) + ", Training Accuracy= " +
                  "{:.3f}".format(acc))
    print("Optimization Finished!")
    # Calculate accuracy for 128 MNIST test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))
Step 1, Minibatch Loss= 2.8645, Training Accuracy= 0.062
Step 200, Minibatch Loss= 2.1180, Training Accuracy= 0.227
Step 400, Minibatch Loss= 1.9726, Training Accuracy= 0.344
Step 600, Minibatch Loss= 1.7784, Training Accuracy= 0.445
Step 800, Minibatch Loss= 1.5500, Training Accuracy= 0.547
Step 1000, Minibatch Loss= 1.5882, Training Accuracy= 0.453
Step 1200, Minibatch Loss= 1.5326, Training Accuracy= 0.555
Step 1400, Minibatch Loss= 1.3682, Training Accuracy= 0.570
Step 1600, Minibatch Loss= 1.3374, Training Accuracy= 0.594
Step 1800, Minibatch Loss= 1.1551, Training Accuracy= 0.648
Step 2000, Minibatch Loss= 1.2116, Training Accuracy= 0.633
Step 2200, Minibatch Loss= 1.1292, Training Accuracy= 0.609
Step 2400, Minibatch Loss= 1.0862, Training Accuracy= 0.680
Step 2600, Minibatch Loss= 1.0501, Training Accuracy= 0.672
Step 2800, Minibatch Loss= 1.0487, Training Accuracy= 0.688
Step 3000, Minibatch Loss= 1.0223, Training Accuracy= 0.727
Step 3200, Minibatch Loss= 1.0418, Training Accuracy= 0.695
Step 3400, Minibatch Loss= 0.8273, Training Accuracy= 0.719
Step 3600, Minibatch Loss= 0.9088, Training Accuracy= 0.727
Step 3800, Minibatch Loss= 0.9243, Training Accuracy= 0.750
Step 4000, Minibatch Loss= 0.8085, Training Accuracy= 0.703
Step 4200, Minibatch Loss= 0.8466, Training Accuracy= 0.711
Step 4400, Minibatch Loss= 0.8973, Training Accuracy= 0.734
Step 4600, Minibatch Loss= 0.7647, Training Accuracy= 0.750
Step 4800, Minibatch Loss= 0.9088, Training Accuracy= 0.742
Step 5000, Minibatch Loss= 0.7906, Training Accuracy= 0.742
Step 5200, Minibatch Loss= 0.7275, Training Accuracy= 0.781
Step 5400, Minibatch Loss= 0.7488, Training Accuracy= 0.789
Step 5600, Minibatch Loss= 0.7517, Training Accuracy= 0.758
Step 5800, Minibatch Loss= 0.7778, Training Accuracy= 0.797
Step 6000, Minibatch Loss= 0.6736, Training Accuracy= 0.742
Step 6200, Minibatch Loss= 0.6552, Training Accuracy= 0.773
Step 6400, Minibatch Loss= 0.5746, Training Accuracy= 0.828
Step 6600, Minibatch Loss= 0.8102, Training Accuracy= 0.727
Step 6800, Minibatch Loss= 0.6669, Training Accuracy= 0.773
Step 7000, Minibatch Loss= 0.6524, Training Accuracy= 0.766
Step 7200, Minibatch Loss= 0.6481, Training Accuracy= 0.805
Step 7400, Minibatch Loss= 0.5743, Training Accuracy= 0.828
Step 7600, Minibatch Loss= 0.6983, Training Accuracy= 0.773
Step 7800, Minibatch Loss= 0.5552, Training Accuracy= 0.828
Step 8000, Minibatch Loss= 0.5728, Training Accuracy= 0.820
Step 8200, Minibatch Loss= 0.5587, Training Accuracy= 0.789
Step 8400, Minibatch Loss= 0.5205, Training Accuracy= 0.836
Step 8600, Minibatch Loss= 0.4266, Training Accuracy= 0.906
Step 8800, Minibatch Loss= 0.7197, Training Accuracy= 0.812
Step 9000, Minibatch Loss= 0.4216, Training Accuracy= 0.852
Step 9200, Minibatch Loss= 0.4448, Training Accuracy= 0.844
Step 9400, Minibatch Loss= 0.3577, Training Accuracy= 0.891
Step 9600, Minibatch Loss= 0.4034, Training Accuracy= 0.883
Step 9800, Minibatch Loss= 0.4747, Training Accuracy= 0.828
Step 10000, Minibatch Loss= 0.5763, Training Accuracy= 0.805
Optimization Finished!
Testing Accuracy: 0.875
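The loss and accuracy figures in the log are a mean softmax cross-entropy and a mean of argmax matches. As a rough NumPy sketch of those two computations on a made-up batch of four samples and three classes (the logits and labels here are invented for illustration and are not model output):

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Made-up logits and one-hot labels
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3],
                   [1.2, 0.2, 0.1],
                   [0.1, 0.2, 3.0]])
labels = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1],
                   [0, 0, 1]], dtype=np.float64)

prediction = softmax(logits)
# Mean cross-entropy over the batch
loss = -np.mean(np.sum(labels * np.log(prediction), axis=1))
# A sample is correct when the predicted class equals the labelled class
correct = np.argmax(prediction, axis=1) == np.argmax(labels, axis=1)
accuracy = correct.mean()
print(accuracy)  # 0.75
```

By the same arithmetic, a testing accuracy of 0.875 on 128 test images corresponds to 112 correctly classified images.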
References
[1] TensorFlow study notes (6): LSTM and GRU. https://blog.csdn.net/u012436149/article/details/52887091
[2] Learning to tell apart an RNN's output and state. https://zhuanlan.zhihu.com/p/28919765