首页 > 其他 > 详细

pytorch中y.data.norm()的含义

时间:2020-06-25 18:35:22      阅读:652      评论:0      收藏:0      [点我收藏+]
import torch
x = torch.randn(3, requires_grad=True)
y = x*2
print(y.data.norm())
print(torch.sqrt(torch.sum(torch.pow(y,2))))  #其实就是对y张量L2范数,先对y中每一项取平方,之后累加,最后取根号
i=0
while y.data.norm()<1000:
  y = y*2
  i+=1
print(y)
print(i)

结果:

tensor(3.7025)
tensor(3.7025, grad_fn=<SqrtBackward>)
tensor([ 1066.4563, -1511.3652,  -414.6933], grad_fn=<MulBackward0>)
9

 

pytorch中y.data.norm()的含义

原文:https://www.cnblogs.com/peixu/p/13192265.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!