首页 > 其他 > 详细

PyTorch中的CrossEntropyLoss和BCEWithLogitcsLoss

时间:2021-04-22 09:13:29      阅读:35      评论:0      收藏:0      [点我收藏+]

CrossEntropyLoss=LogSoftMax+NLLLoss

BCEWithLogitcsLoss=Sigmoid+BCELoss

https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html

https://pytorch.org/docs/master/generated/torch.nn.BCEWithLogitsLoss.html

 

这里主要想说下计算损失时label怎么喂

Example1

>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()

Example2

>>> loss = nn.BCEWithLogitsLoss()
>>> input = torch.randn(3, requires_grad=True)
>>> target = torch.empty(3).random_(2)
>>> output = loss(input, target)
>>> output.backward()

所以CrossEntropy的label大小为N*C,其中N是Batch,C是Class的种类

BCEWithLogitsLoss看来似乎也是这样的。但实际上它的label是可以为小数的。

比如lable是[1, 0, 1],共三个样本,有两种理解形式,一种是第一个样本是正,第二个样本是负,三个样本是正。

第二种理解是三个样本属于正样本的概率分别是100%,0, 100%

按照这样的方式,可以有下面的例子:

Example3

import torch.nn as nn
import torch
import math

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
#target = torch.empty(3).random_(2)
target = torch.tensor([0.7, 0.3, 0.5])#理解成三个样本属于正样本的概率分别是0.7,0.3,0.5。注意,这里是groundtruth。
prob=m(input)
l1=-(target[0]*math.log(prob[0])+(1-target[0])*math.log(1-prob[0]))
l2=-(target[1]*math.log(prob[1])+(1-target[1])*math.log(1-prob[1]))
l3=-(target[2]*math.log(prob[2])+(1-target[2])*math.log(1-prob[2]))
lossAvg=(l1+l2+l3)/3
print(‘lossAvg=‘+str(lossAvg))
output = loss(m(input), target)
print(‘output=‘+str(output))

 

 

PyTorch中的CrossEntropyLoss和BCEWithLogitcsLoss

原文:https://www.cnblogs.com/qq552048250/p/14687658.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!