【动手学深度学习笔记】之图像分类数据集(Fashion-MNIST)
1.图像分类数据集(Fashion-MNIST)
这一章节需要用到torchvision包,为此,我重装了
这个数据集是我们在后面学习中将会用到的图形分类数据集。它的图像内容相较于手写数字识别数据集MINIST更为复杂一些,更加便于我们直观的观察算法之间的差异。
这一节主要使用torchvision包,主要用来构建计算机视觉模型。
torchvision包的主要构成 |
功能 |
---|---|
torchvision.datasets |
一些加载数据的函数及常用数据集接口 |
torchvision.madels |
包含常用的模型结构(含预训练模型) |
torchvision.transforms |
常用的图片变换(裁剪、旋转) |
torchvision.utils |
其他方法 |
1.1获取数据集
首先导入需要的包
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..")
#调用库时,sys.path会自动搜索路径,为了导入d2l这个库,所以需要添加".."
#import d2lzh_pytorch as d2l 这个库找不到不用了
from IPython import display
#在这一节d2l库仅仅在绘图时被使用,因此使用这个库做替代
**通过调用torchvision中的torchvision.datasets来下载这个数据集。**第一次调用从网上自动获取数据。
通过设置参数train来制定获取训练数据集或测试数据集(测试集:用来评估模型表现,并不用来训练模型)。
通过设置参数transfrom = transforms.ToTensor()将所有数据转换成Tensor,如果不进行转换则返回PIL图片。
transforms.ToTensor()函数将尺寸为(H*W*C)且数据位于[0,255]之间的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C*H*W)且数据类型为torch.float32且位于[0,0,1.0]的Tensor C代表通道数,灰度图像的通道数为1 PIL图片是python处理图片的标准 注意:transforms.ToTensor()函数默认将输入类型设置为uint8
#获取训练集
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download = True,transform = transforms.ToTensor())
#获取测试集
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download = True,transform = transforms.ToTensor())
其中mnist_train和mnist_test可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本。
训练集和测试集都有10个类别,训练集中每个类别的图像数为6000,测试集中每个类别的图像数为1000,即:训练集中有60000个样本,测试集中有10000个样本。
len(mnist_train) #输出训练集的样本数
mnist_train[0] #通过下标访问任意一个样本,返回值为两个torch,一个特征tensor和一个标签tensor
Fashion-MNIST数据集中共有十个类别,分别为:t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴) 。
需要将这些文本标签和数值标签相互转换,可以通过以下函数进行。
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
#labels是一个列表
#数值标签转文本标签
下面是一个可以在一行里画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
#绘制矢量图
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
#创建子图,一行len(images)列,图片大小12*12
for f, img, lbl in zip(figs, images, labels):
#zip函数将他们压缩成由多个元组组成的列表
f.imshow(img.view((28, 28)).numpy())
#将img转形为28*28大小的张量,然后转换成numpy数组
f.set_title(lbl)
#设置每个子图的标题为标签
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
#关闭x轴y轴
plt.show()
上述函数的使用
X,y = [],[]
#初始化两个列表
for i in range(10):
X.append(mnist_train[i][0])
#循环向X列表添加图像
y.append(mnist_train[i][1])
#循环向y列表添加标签
show_fashion_mnist(X,get_fashion_mnist_labels(y))
#显示图像和列表
1.2在模型中读取小批量
有了线性回归中读取小批量的经验,我们知道读取小批量可以使用torch中内置的dataloader函数来实现。
dataloader还支持多线程读取数据,通过设置它的num_workers参数。
batch_size = 256
#小批量数目
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle = True,num_workers = 0)
#num_workers=0,不开启多线程读取。
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size = batch_size,shuffle=False,num_workers=0)
1.3获取并显示样本程序
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..")
#由于没有d2l库,所以需要进行一些修改
#修改一:
from IPython import display
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download=True,transform=transforms.ToTensor())
#获取训练集
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=False,download=True,transform=transforms.ToTensor())
#获取测试集
def get_Fashion_MNIST_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
#labels是一个列表,所以有了for循环获取这个列表对应的文本列表
def show_fashion_mnist(images,labels):
display.set_matplotlib_formats('svg')
#绘制矢量图
_,figs = plt.subplots(1,len(images),figsize=(12,12))
#设置添加子图的数量、大小
for f,img,lbl in zip(figs,images,labels):
f.imshow(img.view(28,28).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
#显示部分
X,y =[],[]
#初始化两个列表
for i in range(10):
#先从训练集中抽取出十个
X.append(mnist_train[i][0])
#添加图像
y.append(mnist_train[i][1])
#添加标签
show_fashion_mnist(X,get_Fashion_MNIST_labels(y))
- spring cloud 学习(1) - 基本的SOA示例
- SVN冲突
- 什么叫微信小程序分销系统?如何通过分销系统来实现你的创业梦
- Hadoop(十一)Hadoop IO之序列化与比较功能实现详解
- 安卓第五夜 维纳斯的诞生
- Eclipse中Project的Deployment Assembly(部署程序集)消失了
- spring-boot 速成(9) druid+mybatis 多数据源及读写分离的处理
- Python标准库14 数据库 (sqlite3)
- spring cloud 学习(4) - hystrix 服务熔断处理
- Hadoop(十)Hadoop IO之数据完整性
- Tomcat 端口号修改
- Mac OSX网络诊断命令
- spring cloud 学习(5) - config server
- Java魔法堂:解读基于Type Erasure的泛型
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- java_数据类型转换、运算符
- JavaScript中字符串运算符是什么?有哪些?
- Docker日常使用方式
- 使用Python爬取动态网页-豆瓣电影(JSON)
- Linux安装配置PHPmyadmin
- Angular Service依赖注入的一个具体例子
- php学习day1
- 在Angular里使用rxjs的异步API - Observable
- 自动化监控Oracle表空间并发送报警
- Angular里的消息(Message)显示
- Angular应用内路由(In App Route)的最佳实践
- Angular应用的路由指令RouterLink
- ctfhub-信息泄泄露_备份文件下载
- Angular In-memory Web API使用介绍
- 攻防世界-php_rce