深度学总结:RNN训练需要注意地方:pytorch每一个batch训练之前需要把hidden = hidden.data,否者反向传播的梯度会遍历以前的timestep
时间:2019-02-21
本文章向大家介绍深度学总结:RNN训练需要注意地方:pytorch每一个batch训练之前需要把hidden = hidden.data,否者反向传播的梯度会遍历以前的timestep,主要包括深度学总结:RNN训练需要注意地方:pytorch每一个batch训练之前需要把hidden = hidden.data,否者反向传播的梯度会遍历以前的timestep使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
pytorch每一个batch训练之前需要把hidden = hidden.data,否者反向传播的梯度会遍历以前的timestep
tensorflow也有把new_state更新,但是没有明显detach的操作,预计是tensorflow自己机制默认backpropagation一个timestep的梯度:
for e in range(epochs):
# Train network
new_state = sess.run(model.initial_state)
loss = 0
for x, y in get_batches(encoded, batch_size, num_steps):
counter += 1
start = time.time()
feed = {model.inputs: x,
model.targets: y,
model.keep_prob: keep_prob,
model.initial_state: new_state}
batch_loss, new_state, _ = sess.run([model.loss,
model.final_state,
model.optimizer],
feed_dict=feed)
pytorch每一个batch训练之前需要把hidden = hidden.data,否者反向传播的梯度会遍历以前的timestep,它是自动求导,需要专门把那个state提出来一下,这样就相当于detach了,反向梯度到这里就停止了。
# train the RNN
def train(rnn, n_steps, print_every):
# initialize the hidden state
hidden = None
for batch_i, step in enumerate(range(n_steps)):
# defining the training data
time_steps = np.linspace(step * np.pi, (step+1)*np.pi, seq_length + 1)
data = np.sin(time_steps)
data.resize((seq_length + 1, 1)) # input_size=1
x = data[:-1]
y = data[1:]
# convert data into Tensors
x_tensor = torch.Tensor(x).unsqueeze(0) # unsqueeze gives a 1, batch_size dimension
y_tensor = torch.Tensor(y)
# outputs from the rnn
prediction, hidden = rnn(x_tensor, hidden)
## Representing Memory ##
# make a new variable for hidden and detach the hidden state from its history
# this way, we don't backpropagate through the entire history
hidden = hidden.data
# calculate the loss
loss = criterion(prediction, y_tensor)
# zero gradients
optimizer.zero_grad()
# perform backprop and update weights
loss.backward()
optimizer.step()
# display loss and predictions
if batch_i%print_every == 0:
print('Loss: ', loss.item())
plt.plot(time_steps[1:], x, 'r.') # input
plt.plot(time_steps[1:], prediction.data.numpy().flatten(), 'b.') # predictions
plt.show()
return rnn
- 如何将Markdown文章轻松地搬运到微信公众号并完美地呈现代码内容
- IoC与AOP的那点事儿
- ossec入侵检测日志行为分析
- 从零开始的Spring Session(三)
- 从零开始的Spring Session(一)
- 一个通用的Java正则匹配工具
- 从零开始的Spring Session(二)
- [汇总]2013年度全球重、特大网络安全事件回顾
- android常用接口(一)
- 2014密码时代已死?六种旨在取代传统密码位置的新奇想法
- 程序员你为什么这么累【续】:编码习惯之配置规范
- Spring Security (一) Architecture Overview
- Spring Security (二) Guides
- 一个 android 的框架
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 深入k8s:Pod对象中重要概念及用法
- Golang语言排序的几种方式
- 性能分析(1)- Java 进程导致 CPU 使用率升高,问题怎么定位?
- 安全服务之安全基线及加固(三)Apache篇
- 使用docsify来管理文献
- Cypress系列(41)- Cypress 的测试报告
- SSRF绕过
- 性能测试必备知识(6)- 如何查看“CPU 上下文切换”
- flex布局 div盒子居中
- 使用Apple Configurator 2提取商店ipa or app文件
- Spring 自动装配模式之byType
- 使用ATOMac进行Mac自动化测试
- 【赵渝强老师】什么是Oracle的数据字典?
- antd 如何在 src目录下 引入 Public 目录下的文件
- (精编)Python与安全(三)SSTI服务器模板注入