[PyTorch小试牛刀]实战六·准备自己的数据集用于训练(基于猫狗大战数据集)
时间:2022-06-24
本文章向大家介绍[PyTorch小试牛刀]实战六·准备自己的数据集用于训练(基于猫狗大战数据集),主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
[PyTorch小试牛刀]实战六·准备自己的数据集用于训练(基于猫狗大战数据集)
在上面几个实战中,我们使用的是Pytorch官方准备好的FashionMNIST数据集进行的训练与测试。本篇博文介绍我们如何自己去准备数据集,以应对更多的场景。
我们此次使用的是猫狗大战数据集,开始之前我们要先把数据处理一下,形式如下
datas │ └───train │ │ │ └───cats │ │ │ cat1000.jpg │ │ │ cat1001.jpg │ │ │ … │ └───dogs │ │ │ dog1000.jpg │ │ │ dog1001.jpg │ │ │ … └───valid │ │ │ └───cats │ │ │ cat0.jpg │ │ │ cat1.jpg │ │ │ … │ └───dogs │ │ │ dog0.jpg │ │ │ dog1.jpg │ │ │ …
train数据集中有23000张数据,valid数据集中有2000数据用于验证网络性能
代码部分 1.采用隐形字典形式,代码简练,不易理解
import torch as t
import torchvision as tv
import os
data_dir = "./datas"
BATCH_SIZE = 100
EPOCH = 10
transform = {
x:tv.transforms.Compose(
[tv.transforms.Resize([64,64]),tv.transforms.ToTensor()]#tv.transforms.Resize 用于重设图片大小
)
for x in ["train","valid"]
}
datasets = {
x:tv.datasets.ImageFolder(root = os.path.join(data_dir,x),transform=transform[x])
for x in ["train","valid"]
}
dataloader = {
x:t.utils.data.DataLoader(dataset= datasets[x],
batch_size=BATCH_SIZE,
shuffle=True
)
for x in ["train","valid"]
}
b_x,b_y = next(iter(dataloader["train"]))
print(b_x.shape,b_y.shape)
index_classes = datasets["train"].class_to_idx
print(index_classes)
2.采用显性字典形式,代码稍多,易于理解
import torch as t
import torchvision as tv
data_dir = "./datas"
BATCH_SIZE = 100
EPOCH = 10
transform = {
"train":tv.transforms.Compose(
[tv.transforms.Resize([64,64]),tv.transforms.ToTensor()]
),
"valid":tv.transforms.Compose(
[tv.transforms.Resize([64,64]),tv.transforms.ToTensor()]
),
}
datasets = {
"train":tv.datasets.ImageFolder(root = os.path.join(data_dir,"train"),transform=transform["train"]),
"vaild":tv.datasets.ImageFolder(root = os.path.join(data_dir,"vaild"),transform=transform["vaild"]),
}
dataloader = {
"train":t.utils.data.DataLoader(dataset= datasets["train"],
batch_size=BATCH_SIZE,
shuffle=True
),
"valid":t.utils.data.DataLoader(dataset= datasets["valid"],
batch_size=100,
shuffle=True
)
}
b_x,b_y = next(iter(dataloader["train"]))
print(b_x.shape,b_y.shape)
index_classes = datasets["train"].class_to_idx
print(index_classes)
输出结果
torch.Size([100, 3, 64, 64]) torch.Size([100])
{'cats': 0, 'dogs': 1}
- 通过Nethogs查看服务器网卡流量情况
- 美国国会关于人工智能的提案
- update的多表更新的试验
- silverlight中如何得到ComboBox的选中值(SelectedValue)?
- kvm虚拟化管理平台WebVirtMgr部署-完整记录(安装ubuntu虚拟机)-(5)
- 从MapX到MapXtreme2004[9]-标注的强调显示
- 【第一季】Vue2.0内部指令
- 从MapX到MapXtreme2004[9]-标注的强调显示
- 分布式监控系统Zabbix-3.0.3-完整安装记录(4)-解决zabbix监控图中出现中文乱码问题
- 常用Lambda表达式实例
- centos6.8部署vnc服务
- linux下的缓存机制及清理buffer/cache/swap的方法梳理
- 分组合计且排序和显示名称
- silverlight动态读取txt文件/解析json数据/调用wcf示例
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Elasticsearch:inverted index,doc_values及source
- 在群晖docker上构建私有云IDE和devops构建链
- 小白学PyTorch | 14 tensorboardX可视化教程
- Apache Solr 漏洞复现
- Elasticsearch rollover API
- 重发和重定向有什么区别与重定向应用
- 为tinycolinux制作应用包
- CrossC2的2.0版本
- 使用OpenCV和Python计算图像的“色彩”
- 为tinycolinux创建应用包-toolchain和编译方法
- [译]在Solidity中如何优化Gas第一部分:变量
- [译]Solidity 0.7.0 新变化
- 两个数组的交集 II
- 常说的手机刷新率60Hz、120Hz有什么不同?
- Istio 运维实战系列(3):让人头大的『无头服务』-下