Keras文本数据预处理范例——IMDB影评情感分类
时间:2022-07-22
本文章向大家介绍Keras文本数据预处理范例——IMDB影评情感分类,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
本文将以IMDB电影评论数据集为范例,介绍Keras对文本数据预处理并喂入神经网络模型的方法。
IMDB数据集的目标是根据电影评论的文本内容预测评论的情感标签。训练集有20000条电影评论文本,测试集有5000条电影评论文本,其中正面评论和负面评论都各占一半。
文本数据预处理主要包括中文切词(本示例不涉及),构建词典,序列填充,定义数据管道等步骤。让我们出发吧!
一,准备数据
1,获取数据
在公众号后台回复关键字:imdb,可以获取IMDB数据集的下载链接。数据大小约为13M,解压后约为31M。
数据集结构如下所示。
直观感受一下文本内容。
2,构建词典
为了能够将文本数据喂入模型,我们一般要构建词典,以便将词转换成对应的token(即数字编码)。
from keras.preprocessing.text import Tokenizer
from tqdm import tqdm
# 数据集路径
train_data_path = 'imdb_datasets/xx_train_imdb'
test_data_path = 'imdb_datasets/xx_test_imdb'
train_samples = #训练集样本数量
test_samples = #测试集样本数量
max_words = # 保留词频最高的前10000个词
maxlen = # 每个样本文本内容最多保留500个词
# 构建训练集文本生成器
def texts_gen():
with open(train_data_path,'r',encoding = 'utf-8') as f,
tqdm(total = train_samples) as pbar:
while True:
text = (f.readline().rstrip('n').split('t')[-1])
if not text:
break
if len(text) > maxlen:
text = text[:maxlen]
pbar.update()
yield text
texts = texts_gen()
tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(texts)
看一下我们生成的词典。
3,分割样本
为了能够像ImageDataGenerator那样用数据管道多进程并行地读取数据,我们需要将数据集按样本分割成多个文件。
import os
scatter_train_data_path = 'imdb_datasets/train/'
scatter_test_data_path = 'imdb_datasets/test/'
# 将数据按样本打散到多个文件
def scatter_data(data_file, scatter_data_path):
if not os.path.exists(scatter_data_path):
os.makedirs(scatter_data_path)
for idx,line in tqdm(enumerate(open(data_file,'r',encoding = 'utf-8'))):
with open(scatter_data_path + str(idx) + '.txt','w',
encoding = 'utf-8') as f:
f.write(line)
scatter_data(train_data_path,scatter_train_data_path)
scatter_data(test_data_path,scatter_test_data_path)
4,定义管道
通过继承keras.utils.Sequence类,我们可以构建像ImageDataGenerator那样能够并行读取数据的生成器管道。尽管下面的代码看起来有些长,但通常只有__data_generation方法需要被修改。
# 定义Sequence数据管道, 可以多线程读数据
import keras
import numpy as np
from keras.preprocessing.sequence import pad_sequences
batch_size =
class DataGenerator(keras.utils.Sequence):
def __init__(self,n_samples,data_path,batch_size=batch_size,shuffle=True):
self.data_path = data_path
self.n_samples = n_samples
self.batch_size = batch_size
self.shuffle = shuffle
self.on_epoch_end()
def __len__(self):
return int(np.ceil(self.n_samples/self.batch_size))
def __getitem__(self, index):
# Generate indexes of the batch
batch_indexes = self.indexes[index*self.batch_size:(index+)*self.batch_size]
# Generate data
datas, labels = self.__data_generation(batch_indexes)
return datas, labels
def on_epoch_end(self):
self.indexes = np.arange(self.n_samples)
if self.shuffle == True:
np.random.shuffle(self.indexes)
def __read_file(self,file_name):
with open(file_name,encoding = 'utf-8') as f:
line = f.readline()
return line
def __data_generation(self, batch_indexes):
lines = [self.__read_file(self.data_path + str(i) + '.txt') for i in batch_indexes]
labels = np.array([int(line.strip().split('t')[]) for line in lines])
texts = [line.strip().split('t')[-1] for line in lines]
sequences = tokenizer.texts_to_sequences(texts)
datas = pad_sequences(sequences,maxlen)
return datas,labels
train_gen = DataGenerator(train_samples,scatter_train_data_path)
test_gen = DataGenerator(test_samples,scatter_test_data_path)
二,构建模型
为了将文本token后的整数序列用神经网络进行处理,我们在第一层使用了Embedding层,Embedding层从数学上等效为将输入数据进行onehot编码后的一个全连接层,在形式上以查表方式实现以提升效率。
from keras import models,layers
from keras import backend as K
K.clear_session()
embedding_dim =
model = models.Sequential()
model.add(layers.Embedding(max_words, embedding_dim, input_length=maxlen))
model.add(layers.Flatten())
model.add(layers.Dense(,activation = 'relu'))
model.add(layers.Dense(, activation = 'sigmoid'))
model.summary()
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['acc'])
三,训练模型
epoch_num =
steps_per_epoch = int(np.ceil(train_samples/batch_size))
validation_steps = int(np.ceil(test_samples/batch_size))
history = model.fit_generator(train_gen,
steps_per_epoch = steps_per_epoch,
epochs = epoch_num,
validation_data= test_gen,
validation_steps = validation_steps,
workers=,
use_multiprocessing=False #linux上可使用多进程读取数据
)
四,评估模型
import os
import pandas as pd
# 保存得分
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(, len(acc) + )
dfhistory = pd.DataFrame({'epoch':epochs,'train_loss':loss,'test_loss':val_loss,
'train_acc':acc,'test_acc':val_acc})
print(dfhistory)
五,使用模型
六,保存模型
model.save('imdb_model.h5')
- 走近科学:我是如何入侵Instagram查看你的私人片片的
- 在线手写识别的多卷积神经网络方法
- 苹果发布OS X 10.9.2更新,修复SSL漏洞
- Android内存泄漏终极解决篇(下)
- 利用Volatility查找系统中的恶意DLL
- 雪人行动:利用IE10 0day漏洞的APT攻击剑指美国军方情报
- Android开发:最详细的 Toolbar 开发实践总结
- 关于yubikey对web应用的杞人之忧
- 利用旧版Android漏洞的E-Z-2-Use攻击代码已在Metasploit发布
- Android Studio你不知道的调试技巧
- Android 数据绑定框架DataBinding,堪称解决界面逻辑的黑科技
- 汽车黑客:没有Security就没有Safety
- Android 自定义View高级特效,神奇的贝塞尔曲线
- Android二维码扫描开发(一):实现思路与原理
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 手写一个JDK1.7的简版HashMap
- MySQL存储过程创建与使用
- 一些有意思的JavaScript代码片段
- Flutter 完成全平台制霸:实现 Windows 应用支持
- python包:urllib——使用urllib下载无限制链接图片
- 初探 TensorFlow.js
- 如何使用 Apache Directory Studio 连接 JumpCloud
- 0812-5.16.2-如何获取CDSW上提交Spark作业的真实用户
- GLMM:广义线性混合模型(遗传参数评估)
- 特征锦囊:今天一起搞懂机器学习里的L1与L2正则化
- 【一天一大 lee】二叉搜索树的最近公共祖先 (难度:简单) - Day2020092
- Spring多数据源事务如何玩? | Spring系列46篇
- 使用Mfuzz包做时间序列分析
- 网络安全 | 瑞哥带你全方位解读防火墙技术!
- 【SpringBoot DB 系列】Jooq 之新增记录使用姿势