PyTorch nll_loss and CrossEntropyLoss examples

nll_loss expects inputs that have already been passed through log_softmax. With the default reduction='mean' it returns the average loss: for each sample it takes the log-probability at the target index, negates it, and averages over the batch:
>>> import torch
>>> logits=torch.randn(2,3)
>>> logits
tensor([[-0.1818, -1.2657,  1.6381],
        [ 0.2038,  0.2661, -2.3768]])


>>> probs=torch.nn.functional.log_softmax(logits, dim=-1)
>>> probs
tensor([[-2.0162, -3.1000, -0.1963],
        [-0.7608, -0.6985, -3.3414]])


>>> targets=torch.tensor([0,1])
>>> targets
tensor([0, 1])


>>> loss = torch.nn.functional.nll_loss(probs, targets)
>>> loss
tensor(1.3574)


>>> loss = torch.nn.functional.nll_loss(probs, targets, reduction='none')
>>> loss
tensor([2.0162, 0.6985])


>>> loss = torch.nn.functional.nll_loss(probs, targets, reduction='mean')
>>> loss
tensor(1.3574)


>>> loss = torch.nn.functional.nll_loss(probs, targets, reduction='sum')
>>> loss
tensor(2.7147)


>>> loss = torch.nn.functional.nll_loss(probs, targets, reduction='none')
>>> loss.mean()
tensor(1.3574)
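
The reduction='none' result can also be reproduced by hand: pick out the log-probability at each target index, negate it, and (for the default reduction) average. A minimal sketch, reusing the probs and targets defined above:

>>> # log-probability of the correct class for each sample
>>> picked = probs[torch.arange(len(targets)), targets]
>>> picked
tensor([-2.0162, -0.6985])
>>> # negating matches nll_loss(..., reduction='none')
>>> -picked
tensor([2.0162, 0.6985])
>>> # averaging matches the default reduction='mean'
>>> (-picked).mean()
tensor(1.3574)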

 

Cross-entropy loss takes raw logits as input (no softmax applied): torch.nn.CrossEntropyLoss() computes log_softmax followed by nll_loss.

>>> ce = torch.nn.CrossEntropyLoss()
>>> loss = ce(logits, targets)
>>> loss
tensor(1.3573)
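
Since CrossEntropyLoss on the raw logits and log_softmax followed by nll_loss are the same computation, the two results should agree up to float rounding (the 1.3573 vs. 1.3574 above is only a display artifact). A quick check, reusing the tensors from this session, which should print True with the default tolerances:

>>> two_step = torch.nn.functional.nll_loss(
...     torch.nn.functional.log_softmax(logits, dim=-1), targets)
>>> torch.allclose(ce(logits, targets), two_step)
True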

 


Source: https://www.cnblogs.com/AliceYing/p/14603346.html
