首页 > 其他 > 详细

Pytorch tensor维度变化

时间:2021-08-20 21:26:49      阅读:22      评论:0      收藏:0      [点我收藏+]

发现当我使用DataLoader加载数据的时候使用Module进行前向传播是可以的,但是如果仅仅是对一个img(三维)进行前项传播是不可以的。

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [6, 3, 2, 2], 
but got 3-dimensional input of size [3, 32, 32] instead

发现Dataloader有一个批处理,使得其一个tensor里面包含多个图片,tensor是四维的。

1.增加维度

a = torch.randn(2, 28, 28)

import torch
a = torch.randn(3, 32, 32)
print(a.shape)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(1).shape)
print(a.unsqueeze(2).shape)
print(a.unsqueeze(3).shape)
print(a.unsqueeze(-1).shape)
print(a.unsqueeze(-2).shape)
print(a.unsqueeze(-3).shape)
print(a.unsqueeze(-4).shape)
print(a.unsqueeze(4).shape)

结果:

技术分享图片

2. 删除维度

维度删除的功能并不能做到删除任意维度的数据,只能删除那些size为1的维度

import torch

a = torch.Tensor(1, 4, 1, 9)
print(a.shape)
print(a.squeeze().shape)
print(a.squeeze(0).shape))# 0号维度是1,因此能删除
print(a.squeeze(1).shape)# 1号维度是4,因此不能删除
print(a.squeeze(2).shape)
print(a.squeeze(3).shape)# 3号维度是9,因此不能删除

显示结果:
技术分享图片

详细可见

Pytorch tensor维度变化

原文:https://www.cnblogs.com/xvxing/p/15168093.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!