首页 > 其他 > 详细

torch 中的损失函数

时间:2020-07-20 19:46:02      阅读:94      评论:0      收藏:0      [点我收藏+]
  1. NLLLoss 和 CrossEntropyLoss
    在图片单标签分类时,输入m张图片,输出一个m*N的Tensor,其中N是分类个数。比如输入3张图片,分3类,最后的输出是一个3*3的Tensor
input = torch.tensor([[-0.1123, -0.6028, -0.0450],
              [ 0.1596,  0.2215, -1.0176],
              [-0.2359, -0.7898,  0.7097]])

第123行分别是第123张图片的结果,假设第123列分别是猫、狗和猪的分类得分。
first step: 对每一行使用Softmax,这样可以得到每张图片的概率分布。概率最大的为:1:猪;2:狗;3:猪。

sm = torch.nn.Softmax(dim=1)
sm(input)
tensor([[0.3729, 0.2283, 0.3988],
        [0.4216, 0.4485, 0.1299],
        [0.2410, 0.1385, 0.6205]])

second step: 对softmax结果取对数

torch.log(sm(input))
tensor([[-0.9865, -1.4770, -0.9192],
        [-0.8637, -0.8019, -2.0409],
        [-1.4229, -1.9767, -0.4773]])

Softmax后的数值都在0~1之间,所以log之后值域是负无穷到0。
NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。
假设我们现在Target是[0,2,1](第一张图片是猫,第二张是猪,第三张是狗)。第一行取第0个元素,第二行取第2个,第三行取第1个,去掉负号,结果是:[0.9865,2.0409,1.9767]。再求个均值,结果是:1.66
对比NLLLoss的结果

loss = torch.nn.NLLLoss()
loss(torch.log(sm(input)),target)
# 1.6681

CrossEntropyLoss 相当于上述步骤的组合,Softmax–Log–NLLLoss合并成一步

loss2 = torch.nn.CrossEntropyLoss()
loss2(input,target)
# 1.6681

torch 中的损失函数

原文:https://www.cnblogs.com/leimu/p/13346372.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!