
PyTorch (3) Loss Function


A simple example is used below to illustrate how each loss function works.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

label_numpy = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 1]], dtype=np.float32)  # simulated labels
out_numpy = np.array([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)    # simulated predictions (raw scores / logits)
num_classes = 2

nn.BCELoss

$\ell_n = -w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]$

label = torch.from_numpy(label_numpy).unsqueeze(0)  # add a batch dimension: (1, 4, 4)
output = torch.from_numpy(out_numpy).unsqueeze(0)   # BCELoss is element-wise, so any matching shapes work
# ======================================================= #
criterion = nn.BCELoss()
loss = criterion(torch.sigmoid(output), label)  # 0.6219; BCELoss expects probabilities, so apply sigmoid first
# ======================================================= #
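As a sanity check, the formula can be applied element-wise by hand (with the default weight $w_n = 1$), reusing the label and output tensors above:

x = torch.sigmoid(output)  # probabilities in (0, 1)
manual = -(label * torch.log(x) + (1 - label) * torch.log(1 - x))
print(manual.mean())  # 0.6219, matches nn.BCELoss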

nn.BCEWithLogitsLoss

label = torch.from_numpy(label_numpy).unsqueeze(0)
output = torch.from_numpy(out_numpy).unsqueeze(0)
# ======================================================= #
criterion = nn.BCEWithLogitsLoss()
loss = criterion(output, label) # 0.6219; same value as above, the sigmoid is applied internally
# ======================================================= #

This loss combines a Sigmoid layer and BCELoss in a single class, and it is more numerically stable than applying the two separately.
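
The stability comes from substituting $x = \sigma(z)$ into the BCELoss formula, where $z$ is the raw logit, and simplifying:

$\ell = z - z \cdot y + \log(1 + e^{-z}) = z - z \cdot y + m + \log(e^{-m} + e^{-z-m}), \quad m = \max(0, -z)$

Rearranged this way, no exponential ever receives a large positive argument; this log-sum-exp form is exactly what appears in the code below.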

The concrete computation is as follows:

from torch.nn import _reduction as _Reduction  # private helper used by PyTorch's own implementation

class BCEWithLogitsLoss(nn.Module):
    """
    This version is more numerically stable than a plain Sigmoid followed by a
    BCELoss: by fusing the two operations into one layer, it can use the
    log-sum-exp trick for numerical stability.
    """
    def __init__(self):
        super(BCEWithLogitsLoss, self).__init__()

    def forward(self, input, target, weight=None, size_average=None,
                reduce=None, reduction='mean', pos_weight=None):
        if size_average is not None or reduce is not None:
            reduction = _Reduction.legacy_get_string(size_average, reduce)
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
        max_val = (-input).clamp(min=0)  # m = max(0, -z)
        if pos_weight is None:
            loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
        else:
            # positive examples are up-weighted by pos_weight
            log_weight = 1 + (pos_weight - 1) * target
            loss = input - input * target + log_weight * (max_val + ((-max_val).exp() + (-input - max_val).exp()).log())

        if weight is not None:
            loss = loss * weight
        if reduction == 'none':
            return loss
        elif reduction == 'mean':
            return loss.mean()
        else:
            return loss.sum()
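A quick check that this hand-rolled class agrees with the built-in one on the tensors from above:

criterion_manual = BCEWithLogitsLoss()
criterion_builtin = nn.BCEWithLogitsLoss()
print(criterion_manual(output, label))   # 0.6219
print(criterion_builtin(output, label))  # 0.6219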

nn.CrossEntropyLoss

Here the predictions are first converted to one-hot format, since nn.CrossEntropyLoss expects a row of per-class scores for each sample, with the target given as a class index.

$loss(x, class) = -\log\left(\frac{e^{x[class]}}{\sum_j e^{x[j]}}\right) = -x[class] + \log\left(\sum_j e^{x[j]}\right)$

label_numpy = label_numpy.reshape(-1)         # flatten to 16 samples
out_numpy = out_numpy.reshape(-1)
label = torch.from_numpy(label_numpy).long()  # class indices, shape (16,)
onehot_output = np.eye(num_classes)[np.where(out_numpy >= 0.5, 1, 0)]  # threshold, then one-hot: shape (16, 2)
output = torch.from_numpy(onehot_output)      # per-class scores
# ======================================================= #
criterion = nn.CrossEntropyLoss()
loss = criterion(output, label) # 0.4383
# ======================================================= #
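nn.CrossEntropyLoss is the composition of LogSoftmax and NLLLoss, which can be verified on the same tensors:

log_probs = F.log_softmax(output, dim=1)  # log of the softmax term in the formula above
print(F.nll_loss(log_probs, label))       # 0.4383, same value as nn.CrossEntropyLoss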

The concrete computation is as follows:

out_np = output.numpy()  # (16, 2) per-class scores
lbl_np = label.numpy()   # (16,) class indices

first = np.zeros(out_np.shape[0])
for i in range(out_np.shape[0]):
    first[i] = -out_np[i][lbl_np[i]]       # -x[class]

second = np.zeros(out_np.shape[0])
for i in range(out_np.shape[0]):
    for j in range(out_np.shape[1]):
        second[i] += np.exp(out_np[i][j])  # sum_j e^{x[j]}

res = (first + np.log(second)).mean()      # 0.4383, matches nn.CrossEntropyLoss


Original post: https://www.cnblogs.com/xuanyuyt/p/12923519.html
