自定义PyTorch中的Sampler
时间:2022-07-28
本文章向大家介绍自定义PyTorch中的Sampler,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
本文使用 Zhihu On VSCode 创作并发布
在训练GAN的过程中,一次只训练一个类别据说有助于模型收敛,但是PyTorch里面没有预设这种数据加载方式,要这样训练的话,需要自己定义Sampler,即自定义数据采样方式。下面是自定义的方法:
首先,我们虚构一个Dataset类,用于测试。
这个类中的label标签是混乱的,无法通过控制index范围来实现单类别训练。
class Data(Dataset):
def __init__(self):
self.img = torch.cat([torch.ones(2, 2) for i in range(100)], dim=0)
self.num_classes = 2
self.label = torch.tensor(
[random.randint(0, self.num_classes - 1) for i in range(100)]
)
def __getitem__(self, index):
return self.img[index], self.label[index]
def __len__(self):
return len(self.label)
然后,自定义一个Sampler类,这个类的作用是生成一个index列表,可以理解为重排data中的index。
class CustomSampler(Sampler):
def __init__(self, data):
self.data = data
def __iter__(self):
indices = []
for n in range(self.data.num_classes):
index = torch.where(self.data.label == n)[0]
indices.append(index)
indices = torch.cat(indices, dim=0)
return iter(indices)
def __len__(self):
return len(self.data)
定义好了之后可以封装成DataLoader并查看运行结果:
d = Data()
s = CustomSampler(d)
dl = DataLoader(d, 8, sampler=s)
for img, label in dl:
print(label)
结果
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1])
显然,这样的结果并不能让人满意,有一个batch中还是包含了两种不同类型的标签,为了达到目的,我们还需要再定义一个BatchSampler类
class CustomBatchSampler:
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
i = 0
sampler_list = list(self.sampler)
for idx in sampler_list:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if (
i < len(sampler_list) - 1
and self.sampler.data.label[idx]
!= self.sampler.data.label[sampler_list[i + 1]]
):
if len(batch) > 0 and not self.drop_last:
yield batch
batch = []
else:
batch = []
i += 1
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
虽然PyTorch要求Sampler需要定义成一个迭代器,但是如果你自己定义BatchSampler的话,Sampler的形式可以自己定,就算写成一个普通的列表也没关系。
再次封装成DataLoader并查看运行结果:
d = Data()
s = CustomSampler(d)
bs = CustomBatchSampler(s, 8, False)
dl = DataLoader(d, batch_sampler=bs)
for img, label in dl:
print(label)
drop_last = False 的结果:
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1])
drop_last = True 的结果:
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
以上就是自定义Sampler的步骤了。
- Python标准库01 正则表达式 (re包)
- 剑指OFFER之栈的压入、弹出序列(九度OJ1366)
- Python标准库03 路径与文件 (os.path包, glob包)
- AI人工智能时代已经到来 “北斗即时判”实现纯语音交互
- 剑指OFFER之链表中倒数第k个节点(九度OJ1517)
- 用Qt写软件系列四:定制个性化系统托盘菜单
- Linux简介与厂商版本
- 用Qt写软件系列三:一个简单的系统工具之界面美化
- VS编译链接时错误(Error Link2005)的解决方法
- HttpClient使用心得
- 剑指OFFER之重建二叉树(九度OJ1385)
- 记录visual Studio使用过程中的两个问题
- 剑指OFFER之二维数组中的查找(九度OJ1384)
- Python标准库02 时间与日期 (time, datetime包)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- [Oracle数据泵全解析]expdp交互式命令行模式命令
- SpringBoot Feign文件上传
- Docker_000
- 如何应对面试官的JVM调优问题
- Docker_001
- Docker_002
- [PyQt Tutorial]2.一个Hello World程序
- Oracle设置开机自启
- Go_学习之Docke容器
- zabbix 监控项
- [PyQt Tutorial]4.使用Qt Designer
- [PyQt Tutorial]5.Signals & Slots(信号与槽)
- Docker数据共享与持久化
- [PyQt Tutorial]6.Layout Management(布局管理)
- Kubernetes入门