Focal Loss[1]是一种用来处理单阶段目标检测器训练过程中出现的正负、难易样本不平衡问题的方法。关于Focal Loss,[2]中已经讲的很详细了,这篇博客主要是记录和补充一些细节。
论文中没有用一般多分类任务采取的softmax loss,而是使用了多标签分类中的sigmoid loss,原因是sigmoid的形式训练过程中会更稳定。因此RetinaNet分类subnet输出的通道数是 KA 而不是 (K+1)A(K为类别数,A为每个cell铺的anchor数)。
MMDetection[3]中实现的Focal Loss如下:
# This method is only for debugging
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction=‘mean‘,
avg_factor=None):
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction=‘none‘) * focal_weight
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
论文中给出的公式是:\(FL(p_t)=-\alpha_t(1-p_t)^\gamma\log(p_t)\),下面分析代码的逻辑:
首先给出两个公式:
\(p_t\)为预测值,表示属于 t 类别的概率,可统一表示为:\(p_t=p*t+(1-p)*(1-t)\)
\(\alpha_t\)为权重参数,表示属于 t 类别的权重,可统一表示为:\(\alpha_t=\alpha*t+(1-\alpha)*(1-t)\)
带入得:\(FL(p_t)=-\alpha_t(1-(p*t+(1-p)*(1-t)))^\gamma\log(p_t)\)
举一个例子,设
pred=[0.1, 0.3, 0.8, 0.1, 0.1]
target=[0, 0, 1, 0, 0]
α=0.25
γ=2
# sigmoid value of pred
pred_sigmoid=[0.5250, 0.5744, 0.6900, 0.5250, 0.5250]
直接根据论文的公式计算loss可得:\(FL(pred,target)=\underbrace{-0.75*0.5250^2*\log(1-0.5250)*3 - 0.75*0.5744^2*\log(1-0.5744)}_{negatives}\underbrace{-0.25*(1-0.6900)^2*log(0.6900)}_{positives}=0.1364\)
与上面的py_sigmoid_focal_loss
函数(计算的是平均值)计算结果相同。
因为Focal Loss的本意是将loss集中在正样本上,所以我一直以为α=0.25是负样本的权重,但是调试代码时发现0.25其实是乘在正样本上了。这是一个比较矛盾的地方,因为检测任务中负样本比正样本要多很多,而且大部分都是论文中提到过的easy negatives。自然的想法当然是降低这部分loss的权重,让训练朝着更有意义的方向进行,所以我们给正样本的α设大一点,负样本是1-α,因此会比较小。直到看到[2]评论区的讨论,个人觉得还是比较有说服力的:
重新去查了下focal loss论文,在gamma=0时,alpha=0.75效果更好,但当gamma=2时,alpha=0.25效果更好,个人的解释为负样本(IOU<=0.5)虽然远比正样本(IOU>0.5)要多,但大部分为IOU很小(如<0.1)以至于在gamma作用后某种程度上贡献较大损失的负样本甚至比正样本还要少,所以alpha=0.25要反过来重新平衡负正样本。
大意就是负样本大部分都是容易检测的,用于平衡难易样本地γ取2时,负样本的loss会过度地衰减,因此需要α进行反向地平衡。我没有用代码验证过,不过这些都是超参,研究的意义也不大,定性地分析应该足够。
mmdetection的py_sigmoid_focal_loss
实现其实有一点问题,不能直接替换sigmoid_focal_loss
,不过最近已经修改过了,这部分以后有机会再细说。
MMDetection Sigmoid Focal Loss解析
原文:https://www.cnblogs.com/southtonorth/p/14336845.html