pytorch 中Dataloader中的collate_fn参数
时间:2019-08-15
本文章向大家介绍pytorch 中Dataloader中的collate_fn参数,主要包括pytorch 中Dataloader中的collate_fn参数使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
一般的,默认的collate_fn函数是要求一个batch中的图片都具有相同size(因为要做stack操作),当一个batch中的图片大小都不同时,可以使用自定义的collate_fn函数,则一个batch中的图片不再被stack操作,可以全部存储在一个list中,当然还有对应的label,如下面这个例子:
import torch from torch.utils.data import DataLoader from torchvision import transforms import torchvision.datasets as datasets import matplotlib.pyplot as plt # a simple custom collate function, just to show the idea def my_collate(batch): data = [item[0] for item in batch] target = [item[1] for item in batch] target = torch.LongTensor(target) return [data, target] def show_image_batch(img_list, title=None): num = len(img_list) fig = plt.figure() for i in range(num): ax = fig.add_subplot(1, num, i+1) ax.imshow(img_list[i].numpy().transpose([1,2,0])) ax.set_title(title[i]) plt.show() # do not do randomCrop to show that the custom collate_fn can handle images of different size train_transforms = transforms.Compose([transforms.Scale(size = 224), transforms.ToTensor(), ]) # change root to valid dir in your system, see ImageFolder documentation for more info train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset", transform=train_transforms) trainset = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, collate_fn=my_collate, # use custom collate function here pin_memory=True) trainiter = iter(trainset) imgs, labels = trainiter.next() # print(type(imgs), type(labels)) show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels])
原文地址:https://www.cnblogs.com/zf-blog/p/11360557.html
- MySQL中的半同步复制(r11笔记第65天)
- Linux系统LVM逻辑卷创建过程以及自动化脚本
- 一个闪回区报警的数据恢复(r11笔记第62天)
- 利用腾讯云COS云对象存储定时远程备份网站
- 分享一个自写的Python远程命令和文件(夹)传输类
- Oracle数据误操作全面恢复实战(r11笔记第78天)
- 远程协助解决异常宕库的问题(r11笔记第75天)
- Nginx-helper纯代码版,文章评论发布自动清理Fastcgi缓存
- MySQL和Oracle行值表达式对比(r11笔记第74天)
- 闪回数据库不是“万金油”(r11笔记第73天)
- 修改Apache的超时设置,解决长连接请求超时问题
- Oracle 12cR2初体验(r11笔记第91天)
- MySQL中的undo截断(r11笔记第89天)
- Linux系统 df 命令显示异常、分区丢失问题解决
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Android Notification 使用方法详解
- Android空心圆及层叠效果实现代码
- 如何更改Dialog的标题与按钮颜色详解
- Android编程之数据库的创建方法详解
- android studio集成ijkplayer的示例代码
- Android开发实现浏览器全屏显示功能
- Android动态人脸检测的示例代码(脸数可调)
- Android抽奖轮盘的制作方法
- Android 获取屏幕的多种宽高信息的示例代码
- Android编程实现禁止StatusBar下拉的方法
- Android自定义view圆并随手指移动
- Android仿微信发送语音消息的功能及示例代码
- 详解Android studio ndk配置cmake开发native C
- Android编程实现禁止状态栏下拉的方法详解
- Android进度条ProgressBar的实现代码