在pytorch中动态调整优化器的学习率方式
时间:2022-07-27
本文章向大家介绍在pytorch中动态调整优化器的学习率方式,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。
一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:
step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
lr = base_lr * (0.1 ** np.sum(epoch = np.array(step)))
for params_group in sgd_opt.param_groups:
params_group['lr'] = lr
return lr
只需要在每个train的epoch之前使用这个函数即可。
for epoch in range(60):
model.train()
adjust_lr(epoch)
for ind, each in enumerate(train_loader):
mat, label = each
...
补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取
需要调用的模块及整体Bi-lstm流程
import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
def __init__(self,d_model,embedding_matrix):
super(word_extract, self).__init__()
self.d_model=d_model
self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
self.embedding.weight.data.copy_(embedding_matrix)
self.embedding.weight.requires_grad=False
self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
self.linear=nn.Linear(2*self.lstm2.hidden_size,4)
def forward(self,x):
w_x=self.embedding(x)
first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
output_x=self.linear(second_x)
return output_x
将文本转换为数值形式
def trans_num(word2idx,text):
text_list=[]
for i in text:
s=i.rstrip().replace('r','').replace('n','').split(' ')
numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
text_list.append(numtext)
return text_list
将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中
def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
num2idx = {0: "_PAD"}
vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]
# 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
for i in range(len(vocab_list)):
word = vocab_list[i][0]
word2idx[word] = i + 1
num2idx[i + 1] = word
embeddings_matrix[i + 1] = vocab_list[i][1]
embeddings_matrix = torch.Tensor(embeddings_matrix)
return embeddings_matrix, word2idx, num2idx
训练过程
def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
optimizor = optim.Adam(model.parameters(), lr=learning_rate)
data = TensorDataset(x, y)
data = DataLoader(data, batch_size=batch_size)
for i in range(epoch):
for j, (per_x, per_y) in enumerate(data):
output_y = model(per_x)
loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
optimizor.zero_grad()
loss.backward()
optimizor.step()
arg_y=output_y.argmax(dim=2)
fit_correct=(arg_y==per_y).sum()
fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
print('##################################')
print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
val_output_y = model(val_x)
val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
arg_val_y=val_output_y.argmax(dim=2)
val_correct=(arg_val_y==val_y).sum()
val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
torch.save(model,'./extract_model.pkl')#保存模型
主函数部分
if __name__ =='__main__':
#生成词向量矩阵
word2vec = gensim.models.Word2Vec.load('./word2vec_model')
embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
#
train_data=pd.read_csv('./数据.csv')
x=list(train_data['文本'])
# 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
x=trans_num(word2idx,x)
#x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
# #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
#输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
#填充代码你自行编写,以下部分是针对我的数据集
x=keras.preprocessing.sequence.pad_sequences(
x,maxlen=60,value=0,padding='post',
)
y=list(train_data['BIO数值'])
y_text=[]
for i in y:
s=i.rstrip().split(' ')
numtext=[int(j) for j in s]
y_text.append(numtext)
y=y_text
y=keras.preprocessing.sequence.pad_sequences(
y,maxlen=60,value=3,padding='post',
)
# 将数据进行划分
fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
fit_x=torch.LongTensor(fit_x)
fit_y=torch.LongTensor(fit_y)
val_x=torch.LongTensor(val_x)
val_y=torch.LongTensor(val_y)
#开始应用
w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
pred_val_y=w_extract(val_x).argmax(dim=2)
以上这篇在pytorch中动态调整优化器的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考。
- Gis链接
- TortoiseSVN文件夹及文件图标不显示解决方法 TortoiseSVN文件夹及文件图标不显示解决方法
- 地图坐标
- PowerDesigner15连接Oracle失败的解决办法
- 地图校正方法心得
- 工作流参考模型点评
- 按图索骥:SQL中数据倾斜问题的处理思路与方法
- [方法“Boolean Contains(System.Guid)”不支持转换为 SQL]的解决办法
- DataBind的一些试验
- 继承HibernateDaoSupport时遇到的问题 使用注解为HibernateDaoSupport注入sessionFa
- 常用代码
- 小程序的新功能你知道吗
- Mapxtreme之活活气死
- 仿淘宝的交易到计时JS
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Python | 从 PDF 中提取文本内容
- Stata | 自动生成中南财大2019拟录取硕士研究生分析报告
- Stata | 聊聊数据排序的几种方式
- 在生产中应用广泛的排序算法
- SQL | SQL 必知必会笔记 (一 )
- 如何在树莓派4B上安装EMQ X Broker
- SQL | SQL 必知必会笔记 (二)
- 基于桶子法实现的两种排序算法
- Notes | 微观经济学课堂笔记(一)
- 将终结点图添加到你的ASP.NET Core应用程序中
- Stata | 爬取 CFPS 文献传送门并制作成 Markdown
- 委托的好处
- Elasticsearch安装和配置
- Notes | QUAIDS 模型
- Stata | 520,听说你也想快点找到...