首页 > 其他 > 详细

pytorch中tensor.mean(axis, keepdim)

时间:2020-09-05 13:39:17      阅读:368      评论:0      收藏:0      [点我收藏+]
 1 import numpy as np
 2 import torch
 3 
 4 x=[
 5 [[1,2,3,4],
 6  [5,6,7,8],
 7  [9,10,11,12]],
 8 
 9 [[13,14,15,16],
10  [17,18,19,20],
11  [21,22,23,24]]
12 ]
13 x=torch.tensor(x).float()
14 #
15 print("shape of x:")  ##[2,3,4]
16 print(x.shape)
17 #
18 print("shape of x.mean(axis=0,keepdim=True):")          #[1, 3, 4]
19 print(x.mean(axis=0,keepdim=True).shape)
20 print(x.mean(axis=0,keepdim=True))
21 #
22 print("shape of x.mean(axis=0,keepdim=False):")         #[3, 4]
23 print(x.mean(axis=0,keepdim=False).shape)
24 print(x.mean(axis=0,keepdim=False))
25 #
26 print("shape of x.mean(axis=1,keepdim=True):")          #[2, 1, 4]
27 print(x.mean(axis=1,keepdim=True).shape)
28 print(x.mean(axis=1,keepdim=True))
29 #
30 print("shape of x.mean(axis=1,keepdim=False):")         #[2, 4]
31 print(x.mean(axis=1,keepdim=False).shape)
32 print(x.mean(axis=1,keepdim=False))
33 #
34 print("shape of x.mean(axis=2,keepdim=True):")          #[2, 3, 1]
35 print(x.mean(axis=2,keepdim=True).shape)
36 print(x.mean(axis=2,keepdim=True))
37 #
38 print("shape of x.mean(axis=2,keepdim=False):")         #[2, 3]
39 print(x.mean(axis=2,keepdim=False).shape)
40 print(x.mean(axis=2,keepdim=False))

 

shape of x:
torch.Size([2, 3, 4])
shape of x.mean(axis=0,keepdim=True):
torch.Size([1, 3, 4])
tensor([[[ 7.,  8.,  9., 10.],
         [11., 12., 13., 14.],
         [15., 16., 17., 18.]]])
shape of x.mean(axis=0,keepdim=False):
torch.Size([3, 4])
tensor([[ 7.,  8.,  9., 10.],
        [11., 12., 13., 14.],
        [15., 16., 17., 18.]])
shape of x.mean(axis=1,keepdim=True):
torch.Size([2, 1, 4])
tensor([[[ 5.,  6.,  7.,  8.]],

        [[17., 18., 19., 20.]]])
shape of x.mean(axis=1,keepdim=False):
torch.Size([2, 4])
tensor([[ 5.,  6.,  7.,  8.],
        [17., 18., 19., 20.]])
shape of x.mean(axis=2,keepdim=True):
torch.Size([2, 3, 1])
tensor([[[ 2.5000],
         [ 6.5000],
         [10.5000]],

        [[14.5000],
         [18.5000],
         [22.5000]]])
shape of x.mean(axis=2,keepdim=False):
torch.Size([2, 3])
tensor([[ 2.5000,  6.5000, 10.5000],
        [14.5000, 18.5000, 22.5000]])

 

keepdim=True
运算完之后的维度和原来一样,原来是三维数组现在还是三维数组(不过某一维度变成了1);

keepdim=False
运算完之后一般少一维度,求平均变为1的那一维没有了;

axis=k
按第k维运算,其他维度不遍,第k维变为1

 

pytorch中tensor.mean(axis, keepdim)

原文:https://www.cnblogs.com/tingtin/p/13617470.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!