PyTorch实现重写/改写Dataset并载入Dataloader
时间:2022-07-27
本文章向大家介绍PyTorch实现重写/改写Dataset并载入Dataloader,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
前言
众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets
自带的MNIST、CIFAR-10数据集,一般流程为:
# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)
但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?
我们可以通过改写torch.utils.data.Dataset
中的__getitem__
和__len__
来载入我们自己的数据集。
__getitem__
获取数据集中的数据,__len__
获取整个数据集的长度(即个数)。
改写
采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。
import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
plt.ion() # interactive mode
torch.utils.data.Dataset
是一个抽象类,我们自己的数据集需要继承Dataset
,然后改写上述两个函数:
class ImageLoader(Dataset):
def __init__(self, file_path, transform=None):
super(ImageLoader,self).__init__()
self.file_path = file_path
self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
self.image_names = os.listdir(self.file_path) # 文件名的列表
def __getitem__(self,idx):
image = self.image_names[idx]
image = io.imread(os.path.join(self.file_path,image))
# if self.transform:
# image= self.transform(image)
return image
def __len__(self):
return len(self.image_names)
# 设置自己存放的数据集位置,并plot展示
imageloader = ImageLoader(file_path="D:Projectsdatasetsfaces")
# imageloader.__len__() # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()
得到的图片输出:
得到的数据输出,:
array([[[ 66, 59, 53],
[ 66, 59, 53],
[ 66, 59, 53],
...,
[ 59, 54, 48],
[ 59, 54, 48],
[ 59, 54, 48]],
...,
[153, 141, 129],
[158, 146, 134],
[158, 146, 134]]], dtype=uint8)
上面看到dytpe=uint8
,实际进行训练的时候,常常需要更改成float
的数据类型。可以使用:
# 直接改成pytorch中的tensor下的float格式
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()
改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)
载入到Dataloader
中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。
train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- mysql各种引擎对比、实战
- 接球小游戏玩腻了?换个姿势让PaddleX帮你吊打游戏系统
- mysql事务隔离级别详解和实战
- ELK+FileBeat+Kafka分布式系统搭建图文教程
- Flink CEP 原理和案例详解
- 实战开发,使用 Spring Session 与 Spring security 完成网站登录改造!!
- 分布式计算框架Gearman原理详解
- 【从0开始の全记录】Flume+Kafka+Spark+Spring Boot 统计网页访问量项目
- 系统级性能分析工具perf的介绍与使用[转]
- 深入理解排序算法
- 用nginx缓存静态文件
- 优雅的玩PHP多进程
- 聊一聊mycat数据库集群系列之双主双重实现
- Fast-SCNN的解释以及使用Tensorflow 2.0的实现
- 基于Spring Boot快速实现发送邮件功能