PyTorch4:模块总览&torch.utils.data
1. Pytorch模块总览
相比TensorFlow,PyTorch 是非常轻量级的:相比 TensorFlow 追求兼容并包,PyTorch 把外围功能放在了扩展包中,比如torchtext,以保持主体的轻便。
根据PyTorch 的 API,可知其核心大概如下:
-
torch.nn
&torch.nn.functional
:构建神经网络 -
torch.nn.init
:初始化权重 -
torch.optim
:优化器 -
torch.utils.data
:载入数据
可以说,掌握了上面四个模块和前文中提到的底层 API,至少 80% 的 PyTorch 任务都可以完成。剩下的外围事物则有如下的模块支持:
-
torch.cuda
:管理 GPU 资源 -
torch.distributed
:分布式训练 -
torch.jit
:构建静态图提升性能 -
torch.tensorboard
:神经网络的可视化
如果额外掌握了上面的四个的模块,PyTorch 就只剩下一些边边角角的特殊需求了。
2.torch.utils.data
这个功能包的作用是收集、打包数据,给数据索引,然后按照 batch 将数据分批喂给神经网络。
数据读取的核心是 torch.utils.data.DataLoader
类。它是一个数据迭代读取器,支持
- 映射方式和迭代方式读取数据;
- 自定义数据读取顺序;
- 自动批;
- 单线程或多线程数据读取;
- 自动内存定位。
所有上述功能都可以在 torch.utils.data.DataLoader
的变量中定义:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
最重要的变量为 dataset
,它指明了数据的来源。
DataLoader 支持两种数据类型:
- 映射风格的数据封装(map-style datasets):这种数据结构拥有自定义的
__getitem__()
和__len__()
属性,可以以“索引/值”的方式读取数据,对应torch.utils.data.Dataset
类; - 迭代风格的数据封装(iterable-style datasets):这种数据结构拥有自定义的
__iter__()
属性,通常适用于不方便随机获取数据或不定长数据集的读取上,对应torch.utils.data.IterableDataset
类。
下面我们从顶层的 torch.utils.data.DataLoader
开始,然后一步一步深入到自定义的细节上。为了方便讨论,我们先人工构建一个数据集:
>>> samples = torch.arange(100)
>>> labels = torch.cat([torch.zeros(50), torch.ones(50)], dim=0)
2.1 torch.utils.data.DataLoader 数据加载器
首先看一下常用的变量:
-
dataset
:数据源; -
batch_size
:一个整数,定义每一批读取的元素个数; -
shuffle
:一个布尔值,定义是否随机读取; -
sampler
:定义获取数据的策略,必须与shuffle
互斥; -
num_workers
:一个整数,读取数据使用的线程数; -
collate_fn
:一个将读取的数据处理、聚合成一个一个 batch 的自定义函数; -
drop_last
:一个布尔值,如果最后一批数据的个数不足 batch 的大小,是否保留这个 batch。
dataset
, sampler
和 collate_fn
是自定义的类或功能,我们从后往前看。
2.2 数据集的分割
在介绍这三个变量以前,我们先看看如何将数据集分割,比如分成训练集和测试集。
torch.utils.data.Subset(dataset, indices)
这个函数可以根据索引indices将数据集dataset分割。
>>> even = [i for i in range(100) if i % 2 == 0]
>>> new1 = torch.utils.data.Subset(samples, even)
>>> print(new1[:5])
tensor([0, 2, 4, 6, 8])
torch.utils.data.random_split(dataset, lengths)
先将数据随机排列,然后按照指定的长度进行选择。长度的和必须等于数据集中的数据数量。
>>> train, test = torch.utils.data.random_split(samples, [90, 10])
>>> print(torch.tensor(test))
tensor([79, 60, 98, 74, 31, 43, 21, 69, 55, 76])
2.3. collate_fn 核对函数
这个变量的功能是在数据被读取后,送进模型前对所有数据进行处理、打包。
比如我们有一个不定长度的视频数据集或文本数据集,我们可以自定义一个函数将它们的长度归一化。比如:
>>> a = [[1,2,3],[4,5],[6,7,8,9]]
>>> def collate_fn(data):
... '''
... padding data, so they have same length.
... '''
... max_len = max([len(feature) for feature in data])
... new = torch.zeros(len(data), max_len)
... for i in range(len(data)):
... tmp = torch.as_tensor(data[i])
... j = len(tmp)
... new[i][:j] = tmp
... return new
>>> collate_fn(a)
tensor([[1., 2., 3., 0.],
[4., 5., 0., 0.],
[6., 7., 8., 9.]])
将这个函数赋值给 collate_fn
,在读取数据的时候就可以自动对数据进行 padding 并打包成一个 batch。
2.4 sampler 采样器
这个变量决定了数据读取的顺序。
注意,sampler
只对 iterable-style datasets 有效。
除了可以自定义采样器,Python 内置了几种不同的采样器:
-
torch.utils.data.SequentialSampler(data_source)
默认的采样器。 torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
随机选择数据。可以指定一次读取 num_samples
个数据。replacement
为 True
的话可以指定 num_samples
。
>>> batch = torch.utils.data.RandomSampler(samples, replacement=True, num_samples=5) # 生成一个迭代器
>>> print(list(batch))
[85, 70, 5, 63, 79]
还有三个采样器无法独立使用,必须先实例化,然后放进 DataLoader
:
-
torch.utils.data.SubsetRandomSampler(indices)
:先按照索引选取数据,然后随机排列。 -
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
:字面意思是按照概率选择不同类别的元素。 -
torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
:在一个 batch 中应用另外一个采样器。
2.5 dataset 数据集生成器
torch.utils.data.Dataset
这个类需要覆写 __getitem__
和 __len__
属性。
>>> class MyData(torch.utils.data.Dataset):
... def __init__(self, data):
... super(MyData, self).__init__()
... self.data = data
... def __len__(self, data):
... return len(self.data)
... def __getitem__(self, index):
... return self.data[index]
>>> mydata = MyData(samples)
>>> mydata[0]
tensor(0)
>>> mydata[10:15]
tensor([10, 11, 12, 13, 14])
除此以外,还有若干个 wrapper:
torch.utils.data.IterableDataset
torch.utils.data.TensorDataset(*tensors)
torch.utils.data.ConcatDataset(datasets)
torch.utils.data.ChainDataset(datasets)
2.6 总结
选择让我们把所有知识应用一下。假设我们想以 10 为一个 batch,随机选择数据:
>>> train = data.TensorDataset(torch.as_tensor(samples), torch.as_tensor(labels))
>>> ds = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)
>>> for _ in range(5):
... print(iter(ds).next())
[tensor([35, 19, 99, 58, 59, 10, 26, 86, 24, 25]), tensor([0., 0., 1., 1., 1., 0., 0., 1., 0., 0.])]
[tensor([ 6, 37, 24, 98, 96, 18, 88, 90, 19, 87]), tensor([0., 0., 0., 1., 1., 0., 1., 1., 0., 1.])]
[tensor([80, 75, 48, 34, 90, 67, 8, 63, 47, 32]), tensor([1., 1., 0., 0., 1., 1., 0., 1., 0., 0.])]
[tensor([48, 68, 64, 54, 87, 76, 18, 53, 65, 17]), tensor([0., 1., 1., 1., 1., 1., 0., 1., 1., 0.])]
[tensor([65, 26, 67, 5, 4, 8, 35, 47, 40, 96]), tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 1.])]
- Python中对字节流/二进制流的操作:struct模块简易使用教程
- C++ 后台程序实时性能监控
- 系统入侵后的排查思路及心得
- 记一次Linux被入侵的经历
- C++ FFLIB之ffcount:通用数据分析系统
- Python内置数据结构之迭代器知多少?
- Python之解析式您知多少?
- C++ FFLIB 之FFDB: 使用 Mysql&Sqlite 实现CRUD
- C++ FFLIB之FFXML: 极简化TinyXml 读取
- 架构高性能网站秘笈(五)——Web组件分离
- 安全编程-c++野指针和内存泄漏
- 稳扎稳打JS——this
- FFLIb Demo && CQRS
- springcloud学习手册-Eureka常见问题总结
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Flask搭建ES搜索引擎(二)
- Java 通过RestHighLevelClient 使用ES的date_histogram 根据年月日做统计
- Debug HashMap
- NSum及股票系列
- 从0开始搭建编程框架——主框架和源码
- C++拾取——使用stl标准库生成等差、等比数列的方法
- C++拾取——使用stl标准库实现排序算法及评测
- 机器人实践课程镜像分享及使用说明(Arduino+ROS1+ROS2+Gazebo+SLAM+...)
- ROS 2 Foxy Fitzroy遇见Ubuntu 20.04
- ROS Noetic Ninjemys遇见Ubuntu 20.04
- Kustomize ConfigMapGenerate自动生成ConfigMap中的坑
- ThreadLocal的使用及原理
- 参数绑定
- ndn挖坑记(一)
- Python之QQ邮箱告警脚本