【深度学习】PyTorch之Squeeze()和Unsqueeze()
时间:2019-10-15
本文章向大家介绍【深度学习】PyTorch之Squeeze()和Unsqueeze(),主要包括【深度学习】PyTorch之Squeeze()和Unsqueeze()使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
1. unsqueeze()
该函数用来增加某个维度。在PyTorch中维度是从0开始的。
import torch a = torch.arange(0, 9) print(a)
结果:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
利用view()改变tensor的形状。值得注意的是view不会修改自身的数据,返回的新tensor与源tensor共享内存;同时必须保证前后元素总数一致。
a = a.view(3, 3) print(f"a:{a} \n shape:{a.shape}")
结果:
a:tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) shape:torch.Size([3, 3])
在第一个维度(即维度序号为0)前增加一个维度。
a = a.unsqueeze(0) print(f"a:{a}\nshape:{a.shape}")
结果:
a:tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]) shape:torch.Size([1, 3, 3])
同理,可在其他位置添加维度,在这里就不举例了。
2. squeeze()
该函数用来减少某个维度。
print(f"1. a:{a}\nshape:{a.shape}") a = a.unsqueeze(0) a = a.unsqueeze(2) print(f"2. a:{a}\nshape:{a.shape}") a = a.squeeze(2) print(f"3. a:{a}\nshape:{a.shape}")
结果:
1. a:tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) shape:torch.Size([3, 3]) 2. a:tensor([[[[0, 1, 2]], [[3, 4, 5]], [[6, 7, 8]]]]) shape:torch.Size([1, 3, 1, 3]) 3. a:tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]) shape:torch.Size([1, 3, 3])
3. 下面是运用上述两个函数,并进行一次卷积的例子。
from torchvision.transforms import ToTensor import torch as t from torch import nnimport cv2 import numpy as np import cv2
to_tensor = ToTensor() # 加载图像 lena = cv2.imread('lena.jpg', cv2.IMREAD_GRAYSCALE) cv2.imshow('lena', lena) input = to_tensor(lena) input = to_tensor(lena).unsqueeze(0) # 初始化卷积参数 kernel = t.ones(1, 1, 3, 3)/-9 kernel[:, :, 1, 1] = 1 conv = nn.Conv2d(1, 1, 3, 1, padding=1, bias=False) conv.weight.data = kernel.view(1, 1, 3, 3) # 输出 out = conv(input) out = out.squeeze(0) print(out.shape) out = out.unsqueeze(3) print(out.shape) out = out.squeeze(0) print(out.shape) out = out.detach().numpy()
# 缩放到0~最大值 cv2.normalize(out, out, 1.0, 0, cv2.NORM_INF) cv2.imshow("lena-result", out) cv2.waitKey()
结果:
torch.Size([1, 304, 304]) torch.Size([1, 304, 304, 1]) torch.Size([304, 304, 1]) <class 'numpy.ndarray'> (304, 304, 1)
原文地址:https://www.cnblogs.com/chen-hw/p/11678949.html
- 编程思想 之「语言导论」
- 编程思想 之「对象漫谈」
- Github 项目推荐 | TensorFlow 概率推理工具集 —— probability
- Github 项目推荐 | 用于 C/C++、Java、Matlab/Octave 的特征选择工具箱
- Mercari Price 比赛分享 —— 语言不仅是算法和公式而已
- Github 项目推荐 | GAN 的 Keras 实现案例集合 —— Keras-GAN
- Github 项目推荐 | 微软开源 MMdnn,模型可在多框架间转换
- 半自动化运维之动态添加数据文件(一) (r5笔记第55天)
- 半自动化运维之动态添加数据文件(二) (r5笔记第56天)
- 11g Active DataGuard初探(r5笔记第54天)
- Github 项目推荐 | 用于构建端对端对话系统和训练聊天机器人的开源库 —— DeepPavlov
- 我身边的一些数据库事故 (r5笔记第52天)
- 一个清理脚本的改进思路(r5笔记第51天)
- 【专业技术】Python爬虫:抓取手机APP的传输数据
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 没想到,几行代码,你就可以实现图片压缩(springboot)!
- Go 语言学习之基础数据类型
- Go 语言学习之变量
- Go 泛型的括号选择:[ ] or ( )?
- 如何使用GitLab CI/CD 触发多项目管道
- 使用docker数据卷持久化容器数据
- 面向初学者的Docker快速入门指南
- TypeScript:React、拖拽、实践!
- 太慢不能忍!CPU又拿硬盘和网卡开刀了!
- 懂了!VMware/KVM/Docker原来是这么回事儿
- CPU明明8个核,网卡为啥拼命折腾一号核?
- 2020-07-22-腾讯云-slb-kubeadm高可用集群搭建
- 2020-07-23-kubernetes集群使用腾讯云cbs块存储
- 十一、详解面向对象
- 十二、面向对象实战之封装拖拽对象