Pytorch 重写Dataloader
时间:2020-04-17
本文章向大家介绍Pytorch 重写Dataloader,主要包括Pytorch 重写Dataloader使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
这是一个官网的例子:torch.nn入门。
一般而言,我们会根据自己的数据需求继承Dataset(from torch.utils.data import Dataset, DataLoader)重写数据读取函数。或者利用TensorDataset更加简洁实现读取数据。
抑或利用 torchvision里面的ImageFolder
也可管理数据。这几种方法已经可以实现数据读取了,而DataLoader的作用是更加全面管理批量数据:
下面进入正题,MNIST数据利用CNN时需要转换为二维数据,所以需要对初始的线性数据进行转换。一般,可以读取先行数据后在模型中进行view来实现:
class Lambda(nn.Module): def __init__(self, func): super().__init__() self.func = func def forward(self, x): return self.func(x) def preprocess(x): return x.view(-1, 1, 28, 28) model = nn.Sequential( Lambda(preprocess), nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.AvgPool2d(4), Lambda(lambda x: x.view(x.size(0), -1)), )
文中给出另一种解决方案:重写DateLoader:将数据处理移到生成器里面
def get_data(train_ds, valid_ds, bs): return ( DataLoader(train_ds, batch_size=bs, shuffle=True), DataLoader(valid_ds, batch_size=bs * 2), ) def preprocess(x, y): return x.view(-1, 1, 28, 28), y class WrappedDataLoader: def __init__(self, dl, func): self.dl = dl self.func = func def __len__(self): return len(self.dl) def __iter__(self): batches = iter(self.dl) for b in batches: yield (self.func(*b)) train_dl, valid_dl = get_data(train_ds, valid_ds, bs) train_dl = WrappedDataLoader(train_dl, preprocess) valid_dl = WrappedDataLoader(valid_dl, preprocess)
模型就可以写成这样:
model = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), Lambda(lambda x: x.view(x.size(0), -1)), )
原文地址:https://www.cnblogs.com/king-lps/p/12721758.html
- Java基础巩固——反射
- 手把手教你在树莓派上搭建web服务器
- 安装和搭建基于netcore的demo
- 项目心得:广度遍历搜索部门处理业务
- 使用JAVA开发微信公众平台(一)——环境搭建与开发接入
- BZOJ4805: 欧拉函数求和(杜教筛)
- centos7.x下搭建netcore环境和helloworld的demo
- ARM coretex M4 系统定时器
- 用实例说明如何用JavaScript生成XML
- 51nod 1244 莫比乌斯函数之和(杜教筛)
- 几个比较有意思的JS脚本
- Java多线程高并发学习笔记——阻塞队列
- 使用javascript+xml实现分页
- 使用OAuth打造webapi认证服务供自己的客户端使用
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- n数之和题目要类比——LeetCode题目18:四数之和
- SpringBoot使用MySQL访问数据
- MySQL数据库与JDBC编程
- 自动删除QQ空间指定好友的留言
- 在Ubuntu 18.04中安装VMware工具
- 微信小程序下拉刷新功能
- 详解Linux Screen让程序保持后台运行
- Python Des加密与解密实现软件注册码、机器码
- Excel VBA 在保留原单元格数据的情况下,将计算的百分比加在后面
- 入门级别的面试题——LeetCode题目19:删除链表的倒数第N个节点
- python做web接口测试零散笔记--1
- 要一遍做对——LeetCode题目20:有效的括号
- 双指针算法练习(一)
- 一般是面试的热身题——LeetCode题目21:合并两个有序链表
- LeetCode题目22:括号生成