[Keras深度学习浅尝]实战二·CNN实现Fashion MNIST 数据集分类
时间:2022-06-24
本文章向大家介绍[Keras深度学习浅尝]实战二·CNN实现Fashion MNIST 数据集分类,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
[Keras深度学习浅尝]实战二·CNN实现Fashion MNIST 数据集分类
与我们上篇博文[Keras深度学习浅尝]实战一结构相同,修改的地方有,定义网络与模型训练两部分,可以对比着来看。通过使用CNN结构,预测准确率略有提升,可以通过修改超参数以获得更优结果。 代码部分
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
# Helper libraries
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import numpy as np
import matplotlib.pyplot as plt
EAGER = True
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(train_images.shape,train_labels.shape)
train_images = train_images.reshape([-1,28,28,1]) / 255.0
test_images = test_images.reshape([-1,28,28,1]) / 255.0
model = keras.Sequential([
#(-1,28,28,1)->(-1,28,28,32)
keras.layers.Conv2D(input_shape=(28, 28, 1),filters=32,kernel_size=5,strides=1,padding='same'), # Padding method),
#(-1,28,28,32)->(-1,14,14,32)
keras.layers.MaxPool2D(pool_size=2,strides=2,padding='same'),
#(-1,14,14,32)->(-1,14,14,64)
keras.layers.Conv2D(filters=64,kernel_size=3,strides=1,padding='same'), # Padding method),
#(-1,14,14,64)->(-1,7,7,64)
keras.layers.MaxPool2D(pool_size=2,strides=2,padding='same'),
#(-1,7,7,64)->(-1,7*7*64)
keras.layers.Flatten(),
#(-1,7*7*64)->(-1,256)
keras.layers.Dense(256, activation=tf.nn.relu),
#(-1,256)->(-1,10)
keras.layers.Dense(10, activation=tf.nn.softmax)
])
print(model.summary())
lr = 0.001
epochs = 5
model.compile(optimizer=tf.train.AdamOptimizer(lr),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_images, train_labels, batch_size = 200, epochs=epochs,validation_data=[test_images[:1000],test_labels[:1000]])
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(np.argmax(model.predict(test_images[:10]),1),test_labels[:10])
输出结果
(60000, 28, 28) (60000,)
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 28, 28, 32) 832
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 14, 14, 64) 18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 3136) 0
_________________________________________________________________
dense (Dense) (None, 256) 803072
_________________________________________________________________
dense_1 (Dense) (None, 10) 2570
=================================================================
Total params: 824,970
Trainable params: 824,970
Non-trainable params: 0
_________________________________________________________________
None
Train on 60000 samples, validate on 1000 samples
Epoch 1/5
60000/60000 [==============================] - 64s 1ms/step - loss: 0.3806 - acc: 0.8619 - val_loss: 0.2797 - val_acc: 0.9010
Epoch 2/5
60000/60000 [==============================] - 63s 1ms/step - loss: 0.2495 - acc: 0.9090 - val_loss: 0.2647 - val_acc: 0.9000
Epoch 3/5
60000/60000 [==============================] - 63s 1ms/step - loss: 0.1987 - acc: 0.9255 - val_loss: 0.2725 - val_acc: 0.9000
Epoch 4/5
60000/60000 [==============================] - 63s 1ms/step - loss: 0.1630 - acc: 0.9388 - val_loss: 0.2852 - val_acc: 0.9010
Epoch 5/5
60000/60000 [==============================] - 63s 1ms/step - loss: 0.1314 - acc: 0.9514 - val_loss: 0.2704 - val_acc: 0.9140
[9 2 1 1 6 1 4 6 5 7] [9 2 1 1 6 1 4 6 5 7]
- CentOS mysql配置主从复制
- Quartz依赖数据库表
- Spring Security Oauth2.0 实现短信验证码登录
- 【Spring Cloud】Redis缓存接入监控、运维平台CacheCloud
- 基于Redis实现分布式应用限流
- Jasypt : 整合spring boot加密应用配置文件敏感信息
- Eureka:扩展ClientFilter实现服务注册自定义过滤
- 【系统日志】log4j配置学习总结
- 【译】MySQL char、varchar的区别
- 【jfinal修仙系列】修改ShiroPlugin支持jfinal3.0
- MySQL二进制日志
- 【nginx启动】 97 Address family not supported by protocol
- jfinal 内置的handler功能
- JS 对指定iframe 全屏操作
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- django ObjectDoesNotExist 和 DoesNotExist的用法
- PHP PDOStatement::closeCursor讲解
- 使用python实现下载我们想听的歌曲,速度超快
- JS(jQuery)实现聊天接收到消息语言自动提醒功能详解【提示“您有新的消息请注意查收”】
- OpenCV4.1.0+VS2017环境配置的方法步骤
- 详解如何实现Laravel的服务容器的方法示例
- laravel 数据迁移与 Eloquent ORM的实现方法
- PDO::query讲解
- Laravel5框架自定义错误页面配置操作示例
- PHP-FPM和Nginx的通信机制详解
- PHP PDOStatement::columnCount讲解
- PHP中上传文件打印错误错误类型分析
- Laravel如何创建服务器提供者实例代码
- PDO::quote讲解
- php intval函数用法总结