Keras图像数据预处理范例——Cifar2图片分类
时间:2022-07-22
本文章向大家介绍Keras图像数据预处理范例——Cifar2图片分类,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
本文将以Cifar2数据集为范例,介绍Keras对图片数据进行预处理并喂入神经网络模型的方法。
Cifar2数据集为Cifar10数据集的子集,只包括前两种类别airplane和automobile。
训练集有airplane和automobile图片各5000张,测试集有airplane和automobile图片各1000张。
我们将重点介绍Keras中可以对图片进行数据增强的ImageDataGenerator工具和对内存友好的训练方法fit_generator的使用。让我们出发吧!
一,准备数据
1,获取数据
公众号后台回复关键字:Cifar2,可以获得Cifar2数据集下载链接,数据大约10M,解压后约1.5G。
我们准备的Cifar2数据集的文件结构如下所示。
直观感受一下。
2,数据增强
利用keras中的图片数据预处理工具ImageDataGenerator我们可以轻松地对训练集图片数据设置旋转,翻转,缩放等数据增强。
from keras.preprocessing.image import ImageDataGenerator
train_dir = 'cifar2_datasets/train'
test_dir = 'cifar2_datasets/test'
# 对训练集数据设置数据增强
train_datagen = ImageDataGenerator(
rescale = 1./,
rotation_range=,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
# 对测试集数据无需使用数据增强
test_datagen = ImageDataGenerator(rescale=1./)
数据增强相关参数说明:
- rotation_range是角度值(在 0~180 范围内),表示图像随机旋转的角度范围。
- width_shift 和 height_shift 是图像在水平或垂直方向上平移的范围(相对于总宽 度或总高度的比例)。
- shear_range是随机错切变换的角度。
- zoom_range是图像随机缩放的范围。
- horizontal_flip 是随机将一半图像水平翻转。如果没有水平不对称的假设(比如真 实世界的图像),这种做法是有意义的。
- fill_mode是用于填充新创建像素的方法,这些新像素可能来自于旋转或宽度/高度平移。
查看数据增强效果
import os
from keras.preprocessing import image
from matplotlib import pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'png'
fnames = [os.path.join('cifar2_datasets/train/0_airplane', fname) for
fname in os.listdir('cifar2_datasets/train/0_airplane')]
# 载入第3张图像
img_path = fnames[]
img = image.load_img(img_path, target_size=(, ))
x = image.img_to_array(img)
plt.figure(,figsize = (,))
plt.subplot(,,)
plt.imshow(image.array_to_img(x))
plt.title('original image')
# 数据增强后的图像
x = x.reshape((,) + x.shape)
i =
for batch in train_datagen.flow(x, batch_size=):
plt.subplot(,,i+)
plt.imshow(image.array_to_img(batch[]))
plt.title('after augumentation %d'%(i+))
i = i +
if i % == :
break
plt.show()
3,导入数据
使用ImageDataGenerator的flow_from_directory方法可以从文件夹中导入图片数据,转换成固定尺寸的张量,这个方法将得到一个可以读取图片数据的生成器generator。
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(, ),
batch_size=,
shuffle = True,
class_mode='binary')
test_generator = test_datagen.flow_from_directory(
test_dir,
target_size=(, ),
batch_size=,
shuffle = False,
class_mode='binary')
print(train_generator.class_indices)
二,构建模型
from keras import models,layers,optimizers
from keras import backend as K
K.clear_session()
model = models.Sequential()
model.add(layers.Flatten(input_shape = (,,)))
model.add(layers.Dense(, activation='relu'))
model.add(layers.Dense(, activation='sigmoid'))
model.compile(loss='binary_crossentropy',
optimizer=optimizers.RMSprop(lr=1e-4),
metrics=['acc'])
model.summary()
三,训练模型
# 计算每轮次需要的步数
import numpy as np
train_steps_per_epoch = np.ceil(/)
test_steps_per_epoch = np.ceil(/)
# 使用内存友好的fit_generator方法进行训练
history = model.fit_generator(
train_generator,
steps_per_epoch = train_steps_per_epoch,
epochs = ,
validation_data= test_generator,
validation_steps=test_steps_per_epoch,
workers=, # 读取数据的进程数
use_multiprocessing=False #linux上可使用多进程读取数据
)
四,评估模型
五,使用模型
from sklearn.metrics import roc_auc_score
test_datagen = ImageDataGenerator(rescale=1./)
# 注意,使用模型进行预测时要设置生成器shuffle = False
test_generator = test_datagen.flow_from_directory(
test_dir,
target_size=(, ),
batch_size=,
class_mode='binary',
shuffle = False)
# 计算auc
y_pred = model.predict_generator(test_generator,steps = len(test_generator))
y_pred = np.reshape(y_pred,(-1,))
y_true = np.concatenate([test_generator[i][]
for i in range(len(test_generator))])
auc = roc_auc_score(y_true,y_pred)
print('test auc:',auc)
六,保存模型
model.save('cifar2_model.h5')
- 程序员你为什么这么累【续】:编码习惯-函数编写建议
- 那些年,我们一起碰到过的骗局
- Spring Security (五) 动手实现一个IP_Login
- 史上最全Linux提权后获取敏感信息方法
- Spring Security (四) 核心过滤器源码分析
- Spring Security (三) 核心配置解读
- Spring Cloud配置中心获取不到最新配置信息的问题
- 总是听别人说响应式布局,原来这么简单
- Spring Cloud Zuul重试机制探秘
- Eureka中RetryableClientQuarantineRefreshPercentage参数探秘
- Edgware.RC1中ZuulFallbackProvider的改进
- JPA的多表复杂查询:详细篇
- 尝试使用Memcached遇到的狗血问题
- Enumerable#Zip 实现一下
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- HTTP服务器Nginx服务介绍续
- python第四十六课——函数重写
- Linux系统Memcached服务介绍
- python第四十七课——类属性和函数属性
- python第四十八课——类函数和对象函数
- python第四十九课——对象序列化与反序列化
- python第五十课——多态性
- python第五十一课——__slots
- Linux系统安全配置iptables服务介绍
- ThreadLocal企业中真实应用
- python第五十二课--自定义异常类
- python第五十三课——time模块
- 从亲身经历谈谈如何用Git分支解决项目生产实践中的痛点
- mysql数据库基础命令(一)
- Linux系统Logrotate服务介绍