【tensorflow2.0】处理图片数据-cifar2分类
1、准备数据
cifar2数据集为cifar10数据集的子集,只包括前两种类别airplane和automobile。
训练集有airplane和automobile图片各5000张,测试集有airplane和automobile图片各1000张。
cifar2任务的目标是训练一个模型来对飞机airplane和机动车automobile两种图片进行分类。
我们准备的Cifar2数据集的文件结构如下所示。
在tensorflow中准备图片数据的常用方案有两种,第一种是使用tf.keras中的ImageDataGenerator工具构建图片数据生成器。
第二种是使用tf.data.Dataset搭配tf.image中的一些图片处理方法构建数据管道。
第一种方法更为简单,其使用范例可以参考以下文章。
https://zhuanlan.zhihu.com/p/67466552
第二种方法是TensorFlow的原生方法,更加灵活,使用得当的话也可以获得更好的性能。
我们此处介绍第二种方法。
import tensorflow as tf
from tensorflow.keras import datasets,layers,models
BATCH_SIZE = 100
def load_image(img_path,size = (32,32)):
label = tf.constant(1,tf.int8) if tf.strings.regex_full_match(img_path,".*/automobile/.*")
else tf.constant(0,tf.int8)
img = tf.io.read_file(img_path)
img = tf.image.decode_jpeg(img) #注意此处为jpeg格式
img = tf.image.resize(img,size)/255.0
return(img,label)
# 使用并行化预处理num_parallel_calls 和预存数据prefetch来提升性能
ds_train = tf.data.Dataset.list_files("./data/cifar2/train/*/*.jpg")
.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.shuffle(buffer_size = 1000).batch(BATCH_SIZE)
.prefetch(tf.data.experimental.AUTOTUNE)
ds_test = tf.data.Dataset.list_files("./data/cifar2/test/*/*.jpg")
.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.batch(BATCH_SIZE)
.prefetch(tf.data.experimental.AUTOTUNE)
for x,y in ds_train.take(1):
print(x.shape,y.shape)
(100, 32, 32, 3) (100,)
2、定义模型
使用Keras接口有以下3种方式构建模型:使用Sequential按层顺序构建模型,使用函数式API构建任意结构模型,继承Model基类构建自定义模型。
此处选择使用函数式API构建模型。
tf.keras.backend.clear_session() #清空会话
inputs = layers.Input(shape=(32,32,3))
x = layers.Conv2D(32,kernel_size=(3,3))(inputs)
x = layers.MaxPool2D()(x)
x = layers.Conv2D(64,kernel_size=(5,5))(x)
x = layers.MaxPool2D()(x)
x = layers.Dropout(rate=0.1)(x)
x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)
outputs = layers.Dense(1,activation = 'sigmoid')(x)
model = models.Model(inputs = inputs,outputs = outputs)
model.summary()
3、训练模型
训练模型通常有3种方法,内置fit方法,内置train_on_batch方法,以及自定义训练循环。此处我们选择最常用也最简单的内置fit方法。
import datetime
logdir = "./data/keras_model/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss=tf.keras.losses.binary_crossentropy,
metrics=["accuracy"]
)
history = model.fit(ds_train,epochs= 10,validation_data=ds_test,
callbacks = [tensorboard_callback],workers = 4)
Epoch 1/10
100/100 [==============================] - 2205s 22s/step - loss: 0.4632 - accuracy: 0.7786 - val_loss: 0.3375 - val_accuracy: 0.8620
Epoch 2/10
100/100 [==============================] - 11s 110ms/step - loss: 0.3346 - accuracy: 0.8565 - val_loss: 0.2617 - val_accuracy: 0.8965
Epoch 3/10
100/100 [==============================] - 11s 111ms/step - loss: 0.2687 - accuracy: 0.8883 - val_loss: 0.2183 - val_accuracy: 0.9165
Epoch 4/10
100/100 [==============================] - 11s 110ms/step - loss: 0.2171 - accuracy: 0.9128 - val_loss: 0.1811 - val_accuracy: 0.9280
Epoch 5/10
100/100 [==============================] - 11s 114ms/step - loss: 0.1860 - accuracy: 0.9268 - val_loss: 0.1798 - val_accuracy: 0.9265
Epoch 6/10
100/100 [==============================] - 11s 112ms/step - loss: 0.1646 - accuracy: 0.9358 - val_loss: 0.1818 - val_accuracy: 0.9260
Epoch 7/10
100/100 [==============================] - 11s 113ms/step - loss: 0.1443 - accuracy: 0.9426 - val_loss: 0.1740 - val_accuracy: 0.9290
Epoch 8/10
100/100 [==============================] - 11s 113ms/step - loss: 0.1301 - accuracy: 0.9469 - val_loss: 0.1635 - val_accuracy: 0.9325
Epoch 9/10
100/100 [==============================] - 11s 112ms/step - loss: 0.1096 - accuracy: 0.9585 - val_loss: 0.1758 - val_accuracy: 0.9315
Epoch 10/10
100/100 [==============================] - 11s 113ms/step - loss: 0.0961 - accuracy: 0.9628 - val_loss: 0.1595 - val_accuracy: 0.9415
4、评估模型
# %load_ext tensorboard
# %tensorboard --logdir ./data/keras_model
from tensorboard import notebook
notebook.list()
# 在tensorboard中查看模型
notebook.start("--logdir ./data/keras_model")
或者我们自己绘图:首先我们构造数据
import pandas as pd
dfhistory = pd.DataFrame(history.history)
dfhistory.index = range(1,len(dfhistory) + 1)
dfhistory.index.name = 'epoch'
dfhistory
然后绘制:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
def plot_metric(history, metric):
train_metrics = history.history[metric]
val_metrics = history.history['val_'+metric]
epochs = range(1, len(train_metrics) + 1)
plt.plot(epochs, train_metrics, 'bo--')
plt.plot(epochs, val_metrics, 'ro-')
plt.title('Training and validation '+ metric)
plt.xlabel("Epochs")
plt.ylabel(metric)
plt.legend(["train_"+metric, 'val_'+metric])
plt.show()
plot_metric(history,"loss")
plot_metric(history,"accuracy")
评估模型:
# 可以使用evaluate对数据进行评估
val_loss,val_accuracy = model.evaluate(ds_test,workers=4)
print(val_loss,val_accuracy)
20/20 [==============================] - 2s 80ms/step - loss: 0.1595 - accuracy: 0.9415
0.15954092144966125 0.9415000081062317
5、使用模型
可以使用model.predict(ds_test)进行预测。
也可以使用model.predict_on_batch(x_test)对一个批量进行预测。
model.predict(ds_test)
array([[1.1052408e-01],
[3.4282297e-02],
[2.7046111e-04],
...,
[2.7544077e-03],
[3.4654222e-04],
[9.9993896e-01]], dtype=float32)
for x,y in ds_test.take(1):
print(model.predict_on_batch(x[0:20]))
[[9.8728174e-01]
[2.0267103e-02]
[9.0806475e-03]
[9.9996555e-01]
[4.5376007e-02]
[1.2818890e-03]
[1.8698535e-03]
[2.2900696e-03]
[8.6169255e-01]
[6.2768459e-06]
[1.2383183e-02]
[4.3949869e-02]
[7.9778886e-01]
[9.9822074e-01]
[9.9993134e-01]
[8.6685091e-02]
[3.7480664e-02]
[9.9652690e-01]
[9.2210865e-01]
[1.6160560e-03]]
6、保存模型
推荐使用TensorFlow原生方式保存模型。
# 保存权重,该方式仅仅保存权重张量
model.save_weights('./data/tf_model_weights.ckpt',save_format = "tf")
# 保存模型结构与模型参数到文件,该方式保存的模型具有跨平台性便于部署
model.save('./data/tf_model_savedmodel', save_format="tf")
print('export saved model.')
model_loaded = tf.keras.models.load_model('./data/tf_model_savedmodel')
model_loaded.evaluate(ds_test)
参考:
开源电子书地址:https://lyhue1991.github.io/eat_tensorflow2_in_30_days/
GitHub 项目地址:https://github.com/lyhue1991/eat_tensorflow2_in_30_days
- struts2: 玩转 rest-plugin
- 设置系统环境变量立即生效的VBS脚本
- velocity模板引擎学习(1)
- mybatis 3.x 缓存Cache的使用
- XStream、JAXB 日期(Date)、数字(Number)格式化输出xml
- mac: vmware fusion中cent os启动假死的解决办法
- java:hibernate + oracle之坑爹的clob
- 启用WCF NetTcpBinding的共享端口
- asp中的md5/sha1/sha256算法收集
- UE4从零搭建CF游戏关卡(蓝图篇)
- 通用的序列号生成器库
- 利用Geneva开发SOA的安全模型
- STOMP协议介绍
- ADO.NET实体框架连接串引发的异常:Unable to load the specified metadata resource
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 突击并发编程JUC系列-阻塞队列 BlockingQueue
- Matlab建立SVM,KNN和朴素贝叶斯模型分类绘制ROC曲线
- Python安装TensorFlow 2、tf.keras和深度学习模型的定义
- TensorFlow2 keras深度学习:MLP,CNN,RNN
- Flutter ListView 局部刷新数据、ListView点赞收藏
- R语言公交地铁路线网络图实现数据挖掘实战
- R语言风险价值VaR(Value at Risk)和损失期望值ES(Expected shortfall)的估计
- R语言机器学习实战之多项式回归
- 5000字!带你零距离接触websocket!
- 使用 GitLab CI 和 Docker 自动部署 Spring Boot 应用
- 玩转StyleGAN2模型:教你生成动漫人物
- R语言时间序列数据指数平滑法分析交互式动态可视化
- 再见Excel!最强国产开源在线表格Luckysheet走红GitHub
- R语言广义线性模型索赔频率预测:过度分散、风险暴露数和树状图可视化
- R语言多分类logistic逻辑回归模型在混合分布模拟单个风险损失值评估的应用