keras实战教程二(文本分类BiLSTM)
时间:2020-05-26
本文章向大家介绍keras实战教程二(文本分类BiLSTM),主要包括keras实战教程二(文本分类BiLSTM)使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
什么是文本分类
给模型输入一句话,让模型判断这句话的类别(预定义)。
以文本情感分类为例
输入:的确是专业,用心做,出品方面都给好评。
输出:2
输出可以是[0,1,2]其中一个,0表示情感消极,1表示情感中性,2表示情感积极。
数据样式
网上应该能找到相关数据。
模型图
训练过程
仅仅作为测试训练一轮
代码
读取数据
import numpy as np from gensim.models.word2vec import Word2Vec from gensim.corpora.dictionary import Dictionary from gensim import models import pandas as pd import jieba import logging from keras import Sequential from keras.preprocessing.sequence import pad_sequences from keras.layers import Bidirectional,LSTM,Dense,Embedding,Dropout,Activation,Softmax from sklearn.model_selection import train_test_split from keras.utils import np_utils def read_data(data_path): senlist = [] labellist = [] with open(data_path, "r",encoding='gb2312',errors='ignore') as f: for data in f.readlines(): data = data.strip() sen = data.split("\t")[2] label = data.split("\t")[3] if sen != "" and (label =="0" or label=="1" or label=="2" ) : senlist.append(sen) labellist.append(label) else: pass assert(len(senlist) == len(labellist)) return senlist ,labellist sentences,labels = read_data("data_train.csv")
词向量
def train_word2vec(sentences,save_path): sentences_seg = [] sen_str = "\n".join(sentences) res = jieba.lcut(sen_str) seg_str = " ".join(res) sen_list = seg_str.split("\n") for i in sen_list: sentences_seg.append(i.split()) print("开始训练词向量") # logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) model = Word2Vec(sentences_seg, size=100, # 词向量维度 min_count=5, # 词频阈值 window=5) # 窗口大小 model.save(save_path) return model model = train_word2vec(sentences,'word2vec.model')
数据处理
def generate_id2wec(word2vec_model): gensim_dict = Dictionary() gensim_dict.doc2bow(model.wv.vocab.keys(), allow_update=True) w2id = {v: k + 1 for k, v in gensim_dict.items()} # 词语的索引,从1开始编号 w2vec = {word: model[word] for word in w2id.keys()} # 词语的词向量 n_vocabs = len(w2id) + 1 embedding_weights = np.zeros((n_vocabs, 100)) for w, index in w2id.items(): # 从索引为1的词语开始,用词向量填充矩阵 embedding_weights[index, :] = w2vec[w] return w2id,embedding_weights def text_to_array(w2index, senlist): # 文本转为索引数字模式 sentences_array = [] for sen in senlist: new_sen = [ w2index.get(word,0) for word in sen] # 单词转索引数字 sentences_array.append(new_sen) return np.array(sentences_array) def prepare_data(w2id,sentences,labels,max_len=200): X_train, X_val, y_train, y_val = train_test_split(sentences,labels, test_size=0.2) X_train = text_to_array(w2id, X_train) X_val = text_to_array(w2id, X_val) X_train = pad_sequences(X_train, maxlen=max_len) X_val = pad_sequences(X_val, maxlen=max_len) return np.array(X_train), np_utils.to_categorical(y_train) ,np.array(X_val), np_utils.to_categorical(y_val)
w2id,embedding_weights = generate_id2wec(model)# 获取词向量矩阵和词典
x_train,y_trian, x_val , y_val = prepare_data(w2id,sentences,labels,200)#将数据处理成模型需要的格式
构建模型
class Sentiment: def __init__(self,w2id,embedding_weights,Embedding_dim,maxlen,labels_category): self.Embedding_dim = Embedding_dim self.embedding_weights = embedding_weights self.vocab = w2id self.labels_category = labels_category self.maxlen = maxlen self.model = self.build_model() def build_model(self): model = Sequential() #input dim(140,100) model.add(Embedding(output_dim = self.Embedding_dim, input_dim=len(self.vocab)+1, weights=[self.embedding_weights], input_length=self.maxlen)) model.add(Bidirectional(LSTM(50),merge_mode='concat')) model.add(Dropout(0.5)) model.add(Dense(self.labels_category)) model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) model.summary() return model def train(self,X_train, y_train,X_test, y_test,n_epoch=5 ): self.model.fit(X_train, y_train, batch_size=32, epochs=n_epoch, validation_data=(X_test, y_test)) self.model.save('sentiment.h5') def predict(self,model_path,new_sen): model = self.model model.load_weights(model_path) new_sen_list = jieba.lcut(new_sen) sen2id =[ self.vocab.get(word,0) for word in new_sen_list] sen_input = pad_sequences([sen2id], maxlen=self.maxlen) res = model.predict(sen_input)[0] return np.argmax(res)
senti = Sentiment(w2id,embedding_weights,100,200,3)
训练预测
senti.train(x_train,y_trian, x_val ,y_val,1)#训练
label_dic = {0:"消极的",1:"中性的",2:"积极的"} sen_new = "现如今的公司能够做成这样已经很不错了,微订点单网站的信息更新很及时,内容来源很真实" pre = senti.predict("./sentiment.h5",sen_new) print("'{}'的情感是:\n{}".format(sen_new,label_dic.get(pre)))
参考https://www.jianshu.com/p/fba7df3a76fa
原文地址:https://www.cnblogs.com/pergrand/p/12967019.html
- 用pandas 进行投资分析
- 【专业技术】android 应用程序如何获取root权限
- Nginx+Keepalived(双机热备)搭建高可用负载均衡环境(HA)
- SpringMVC+MongoDB+Maven整合(微信回调Oauth授权)
- ZeroClipboard实现多个浏览器兼容的复制文本到剪贴板的功能
- Shiro 权限框架使用总结
- Apriori算法介绍(Python实现)
- linux学习第六十二篇:添加自定义监控项目,配置邮件告警,测试告警,不发邮件的问题处理
- Entity Framework Core 2.0 入门
- Nodejs开发框架Express3.0开发手记–从零开始
- 使用 nvm 管理不同版本的 node 与 npm
- svg矢量图绘制以及转换为Android可用的VectorDrawable资源
- CListCtrl控件使用方法总结
- JavaScript基础考核真题——你能全做对吗?
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- deepin linux 手动升级内核的方法
- UGL之单色位图
- Linux网络启动问题:Device does not seem to be present解决办法
- UGL之透明位图
- 关于ISR
- python 井字棋-文字版(下)
- Linux下nginx生成日志自动切割的实现方法
- Centos 7.2中双网卡绑定及相关问题踩坑记录
- Linux双网卡绑定实现负载均衡详解
- 单台服务器中利用Apache的VirtualHost如何搭建多个Web站点详解
- linux系统下MongoDB单节点安装教程
- Centos 7系统虚拟机桥接模式详解
- Centos 6中编译配置httpd2.4的多种方法详解
- linux的最大打开文件数限制修改方法
- Shell中如何删除文本比较长的行的实现方法