小白学PyTorch | 15 TF2实现一个简单的服装分类任务
【机器学习炼丹术】的学习笔记分享
参考目录:
- 0 为什么学TF
- 1 Tensorflow的安装
- 2 数据集构建
- 2 预处理
- 3 构建模型
- 4 优化器
- 5 训练与预测
0 为什么学TF
之前的15节课的pytorch的学习,应该是让不少朋友对PyTorch有了一个全面而深刻的认识了吧 (如果你认真跑代码了并且认真看文章了的话) 。
大家都会比较Tensorflow2和pytorch之间孰优孰劣,但是我们也并不是非要二者选一,两者都是深度学习的工具,其实我们或多或少应该了解一些比较好。就好比,PyTorch是冲锋枪,TensorFlow是步枪,在上战场前,我们可以选择带上冲锋枪还是步枪,但是在战场上,可能手中的枪支没有子弹了,你只能在地上随便捡了一把枪。很多时候,用Pytorch还是Tensorflow的选择权不在自己。
此外,了解了TensorFlow,大家才能更好的理解PyTorch和TF究竟有什么区别。我见过有的大佬是TF和PyTorch一起用在一个项目中,数据读取用PyTorch然后模型用TF构建。
总之,大家有时间有精力的话,顺便学学TF也不亏,更何况TF2.0现在已经优化了很多。本系列预计用3节课来简单的入门一下Tensorflow2.
和PyTorch的第一课一样,我们直接做一个简单的小实战。MNIST手写数字分类,Fashion MNIST时尚服装分类。
1 Tensorflow的安装
安装TensorFlow的方法很简单,就是在控制台执行:
pip install tensorflow --user
这里的--user
是赋予这个命令执行权限的,一般我都会带上。
2 数据集构建
# keras是TF的高级API,用起来更加的方便,一般也是用keras。
import tensorflow as tf
from tensorflow import keras
import numpy as np
导入需要用到的库函数. 正如torchvision.datasets
中一样,keras.datasets
中也封装了一些常用的数据集。
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print('train_images shape:',train_images.shape)
print('train_labels shape:',train_labels.shape)
print('test_images shape:',test_images.shape)
print('test_labels shape:',test_labels.shape)
输出结果是:
训练数据集中有60000个样本,每一个样本和MNIST手写数字大小是一样的,是
大小的,然后每一个样本有一个标签,这个标签和MNIST也是一样的,是从0到9,是一个十分类任务。
来看一下这些类别有哪些:
标签 |
类别 |
标签 |
类别 |
---|---|---|---|
0 |
T-shirt |
5 |
Sandal |
1 |
Trouser |
6 |
Shirt |
2 |
Pullover |
7 |
Sneaker |
3 |
Dress |
8 |
Bag |
4 |
Coat |
9 |
Ankle boot |
这里学学单词吧:
- T-shirt就是T型的衬衫,就是短袖,我感觉前面没有扣子的那种也叫T-shirt;
- Shirt就是长袖的那种衬衫;
- Trouser是裤子;
- pullover是毛衣,套头毛衣,就是常说的卫衣吧感觉;
- dress连衣裙;
- coat是外套;
- sandal是凉鞋;
- sneaker是运动鞋;
- ankle boot是短靴,是到脚踝的那种靴子;
- 这里补充一个吧,sweater,是毛线衣,运动衫,这个和pullover有些类似,个人感觉主要的区分在于运动系列的可以叫做sweater,其他的毛衣卫衣是pullover。
运动短袖T-shirt+运动卫衣sweater是我秋天去健身房的穿搭。
2 预处理
这里不做图像增强之类的了,上面的数据中,图像像素值是从0到255的,我们要把这些标准化成0到1的范围。
train_images = train_images / 255.0
test_images = test_images / 255.0
3 构建模型
# 模型搭建
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
这就是一个用keras构建简单模型的例子:
-
keras.layers.Flatten
是把
的二维度拉平成一个维度,因为这里是直接用全连接层而不是卷积层进行处理的;
- 后面跟上两个全连接层
keras.layers.Dense()
就行了。我们可以发现,这个全连接层的参数和PyTorch是有一些区别的:- PyTorch的全连接层需要一个输入神经元数量和输出数量
torch.nn.Linear(5,10)
,而keras中的Dense是不需要输入参数的keras.layers.Dense(10)
; - keras中的激活层直接封装在了Dense函数里面,所以不需要像PyTorch一样单独写一个
nn.ReLU()
了。
- PyTorch的全连接层需要一个输入神经元数量和输出数量
4 优化器
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
定义优化器和损失函数,在keras中叫做对模型进行编译compile(在C语言中,在运行代码之前都需要对代码进行编译嘛)。损失函数和优化器还有metric衡量指标的设置都在模型的编译函数中设置完成。
上面使用Adam作为优化器,然后损失函数用了交叉熵,然后衡量模型性能的使用了准确率Accuracy。
5 训练与预测
model.fit(train_images, train_labels, epochs=10)
这就是训练过程,相比PyTorch而言,更加的简单简洁,但是不像PyTorch那样灵活。
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('nTest accuracy:', test_acc)
这个.evaluate
方法是对模型的验证集进行验证的,因为本次任务中并没有对训练数据再划分出验证集,所以这里直接使用测试数据了。
大家应该能理解训练集、验证集和测试集的用途和区别吧,我在第二课讲过这个内容,在此不多加赘述。
predictions = model.predict(test_images)
这个.predict
方法才是用在测试集上,进行未知标签样本的类别推理的。
本次内容到此为止,大家应该对keras和tensorflow有一个直观浅显的认识了。当然tensorflow也有一套类似于PyTorch中的dataset,dataloader的那样自定义的数据集加载器的方法,在后续内容中会深入浅出的学一下。
- END -
- Java 数据类型转换
- Spring boot with Scheduling
- Spring Properties 文件读取
- 【学术】你真的知道什么是随机森林吗?本文是关于随机森林的直观解读
- Spring boot 将 Session 放入 Redis
- 【教程】估算一个最佳学习速率,以更好地训练深度神经网络
- SNS 数据库设计
- CentOS7 下 MySQL 5.7 重置root密码
- 通过简单的线性回归理解机器学习的基本原理
- 消息队列在使用中的注意事项
- 【教程】OpenCV—Node.js教程系列:用Tensorflow和Caffe“做游戏”
- 验证码,再见!利用机器学习在15分钟内破解验证码
- Spring boot with Redis
- SOA 面向服务框架设计与实现
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 使用ML 和 DNN 建模的技巧总结
- 医学图像分割模型U-Net介绍和Kaggle的Top1解决方案源码解析
- 机器学习中的音频特征:理解Mel频谱图
- 兄弟,如何淡定地渡过七夕?
- Spring 源码第 9 篇,深入分析 FactoryBean
- PowerBI 动态数据格式 高级版 以及重要通知
- 气哭老板的顶级密钥存放方案,又做了一件蠢事
- 构建没有数据集的辣辣椒分类器,准确性达到96%
- 由 Redis 分布式锁造成的重大事故
- 10分钟搞定 Java 并发队列好吗?好的
- MySQL 案例:关于程序端的连接池与数据库的连接数
- spark和kafka jar包冲突NoSuchMethodError: net.jpountz.lz4.LZ4BlockInputStream
- 聊聊claudb的scripting command
- PHP怎么获取视频总时长的函数方法
- 构建Docker私有仓库