首页 > 其他 > 详细

Pytorch基础学习教程——切片与索引

时间:2021-05-08 16:32:52      阅读:18      评论:0      收藏:0      [点我收藏+]

一、获取Tensor中的部分数据:切片

技术分享图片
 1 #加载需要的库文件
 2 import torch 
 3 import torch.nn as nn 
 4 #输入Tensor : A = [batch_size,channels,width,height]
 5 A = torch.rand(16,3,512,512) # 表示16张行列号均为512的彩色图像
 6 #获取单张图像shape--->[channels,width,height]:
 7 img_shape = A[0].shape
 8 print("img_shape : ",img_shape)
 9 #获取图像的行列号--->[width,height]:
10 img_Size = A[0][0].shape
11 print("img_Size : ",img_Size)
12 #获取batch内第0个通道数据,冒号代表选择全部:
13 zero_imgs = A[:,0,:,:]
14 print("zero_imgs shape : ",zero_imgs.shape)
15 #获取batch内属于第0和第1通道数据,冒号代表选择全部
16 batch_two = A[:,:2,:,:]
17 print("batch_two shape : ",batch_two.shape)
18 #获取batch内属于第3通道的数据
19 batch_last = A[:,-1:,:,:]     #负号表示从后向前获取
20 batch_Last = A[:,2:,:,:]
21 print("batch_last shape : ",batch_last.shape)
22 print("batch_Last shape : ",batch_Last.shape)
23 #获取batch内前5张图像
24 batch_img = A[:5,:,:,:]
25 print("batch_img : ",batch_img.shape)
26 #以步长为2获取batch内的图像数据
27 batch_data = A[:,:,0:512:2,0:512:2]
28 batch_Data = A[:,:,::2,::2]
29 print("batch_data shape  : ",batch_data.shape)
30 print("batch_Data shape  : ",batch_Data.shape)
View Code

结果为:

技术分享图片
1 img_shape :  torch.Size([3, 512, 512])
2 img_Size :  torch.Size([512, 512])
3 zero_imgs shape :  torch.Size([16, 512, 512])
4 batch_two shape :  torch.Size([16, 2, 512, 512])
5 batch_last shape :  torch.Size([16, 1, 512, 512])
6 batch_Last shape :  torch.Size([16, 2, 512, 512])
7 batch_img :  torch.Size([5, 3, 512, 512])
8 batch_data shape  :  torch.Size([16, 3, 256, 256])
9 batch_Data shape  :  torch.Size([16, 3, 256, 256])
View Code

二、利用数据索引回去Tensor

torch.index_select(input, dim, index, out=None)

input : 输入的Tensor

dim : 索引的轴

index :索引的轴的索引

out : 目标Tensor

技术分享图片
 1 A = torch.rand(3,4)
 2 print(A)
 3 #获取第一行,第三行数据
 4 indices = torch.LongTensor([0,2])
 5 data = torch.index_select(A, 0, indices)
 6 print(data)
 7 #获取第三列数据
 8 indices = torch.LongTensor([2])
 9 data = torch.index_select(A, 1, indices)
10 print(data)
View Code

 结果为:

技术分享图片
 1 tensor([[0.3413, 0.4133, 0.8881, 0.3013],
 2         [0.2296, 0.8172, 0.6642, 0.0631],
 3         [0.0300, 0.5728, 0.0011, 0.6917]])
 4         
 5 tensor([[0.3413, 0.4133, 0.8881, 0.3013],
 6         [0.0300, 0.5728, 0.0011, 0.6917]])
 7         
 8 tensor([[0.8881],
 9         [0.6642],
10         [0.0011]])
View Code

mask_select会将满足mask(掩码)的指示,将满足条件的点选出来。将取值返回到一个新的1D张量

技术分享图片
1 A = torch.rand(3,4)
2 print(A)
3 mask = A.ge(0.5)
4 mask_data = torch.masked_select(A, mask)
5 print(mask_data)
View Code

结果为:

技术分享图片
1 tensor([[0.3891, 0.6926, 0.3271, 0.9869],
2         [0.4256, 0.9135, 0.3202, 0.8631],
3         [0.2141, 0.6315, 0.4464, 0.4018]])
4         
5  tensor([0.6926, 0.9869, 0.9135, 0.8631, 0.6315])
View Code

三、其他操作

torch.cat(inputs, dimension=0) → Tensor

inputs (sequence of Tensors)

dimension (int, optional) – 沿着此维连接张量序列。

在给定维度上对输入的张量序列 seq 进行连接操作

torch.nonzero(input, out=None) → LongTensor

input (Tensor) – 源张量

out (LongTensor, optional) – 包含索引值的结果张量

torch.split(tensor, split_size, dim=0)

tensor (Tensor) – 待分割张量

split_size (int) – 单个分块的形状大小

dim (int) – 沿着此维进行分割

torch.squeeze(input, dim=None, out=None)

input (Tensor) – 输入张量

dim (int, optional) – 如果给定,则input只会在给定维度挤压

out (Tensor, optional) – 输出张量

torch.stack(sequence, dim=0)

沿着一个新维度对输入张量序列进行连接。序列中所有的张量都应该为相同 形状

 

声明

本文是很久之前所记笔记,如有侵权,并非有意,可联系我进行删除!

 

Pytorch基础学习教程——切片与索引

原文:https://www.cnblogs.com/yzj-notes/p/14743935.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!