Pytorch tensor维度变化

时间:2021-08-20
本文章向大家介绍Pytorch tensor维度变化,主要包括Pytorch tensor维度变化使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

发现当我使用DataLoader加载数据的时候使用Module进行前向传播是可以的,但是如果仅仅是对一个img(三维)进行前项传播是不可以的。

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [6, 3, 2, 2], 
but got 3-dimensional input of size [3, 32, 32] instead

发现Dataloader有一个批处理,使得其一个tensor里面包含多个图片,tensor是四维的。

1.增加维度

a = torch.randn(2, 28, 28)

import torch
a = torch.randn(3, 32, 32)
print(a.shape)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(1).shape)
print(a.unsqueeze(2).shape)
print(a.unsqueeze(3).shape)
print(a.unsqueeze(-1).shape)
print(a.unsqueeze(-2).shape)
print(a.unsqueeze(-3).shape)
print(a.unsqueeze(-4).shape)
print(a.unsqueeze(4).shape)

结果:

2. 删除维度

维度删除的功能并不能做到删除任意维度的数据,只能删除那些size为1的维度

import torch

a = torch.Tensor(1, 4, 1, 9)
print(a.shape)
print(a.squeeze().shape)
print(a.squeeze(0).shape))# 0号维度是1,因此能删除
print(a.squeeze(1).shape)# 1号维度是4,因此不能删除
print(a.squeeze(2).shape)
print(a.squeeze(3).shape)# 3号维度是9,因此不能删除

显示结果:

详细可见

原文地址:https://www.cnblogs.com/xvxing/p/15168093.html