Focal Loss论文解读和调参教程

这篇具有很好参考价值的文章主要介绍了Focal Loss论文解读和调参教程。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

论文:Focal Loss for Dense Object Detection

论文papar地址:ICCV 2017 Open Access Repository

在各个主流深度学习框架里基本都有实现,本文会以mmcv里的focal loss实现为例(基于pytorch)

简介:

本文是何恺明团队ICCV 2017的一篇文章,主要针对检测场景类别不均衡导致一阶段算法没有二阶段算法精度高,在CE loss的基础上进行改进,提出了Focal Loss,并且本文改动了faster rcnn,魔改成了一个一阶段的算法RetinaNet,也是后续很多工作拿来当baseline的anchor-based一阶段算法。

动机是作者认为,一阶段和二阶段算法的精度差距,主要原因是一阶段基本都是dense detect(指采样的区域很密集,简而言之就是anchor box/proposal很多),而二阶段的算法是精选出高质量的样本(比如RPN、selective search),在二阶段产生相对较少的ROI进行回归和分类预测。一阶段产生那么多anchor ,但是其中只有一小部分变成最后预测的bbox result,因此会有很多易分类负样本在loss function里占很大的比重,就会不利于训练。也就是说Focal Loss的贡献就是缓解了类别不平衡问题(注意:这里的类别不平衡不单单是指正负样本数量的不平衡,还有难易样本数量的不平衡)。

Focal Loss具体原理

修改是基于CE loss的(因此focal loss是分类的loss,当然也用于检测框的分类,只是跟回归无关),首先为正样本加入权重因子α,这样的操作一般叫Balanced Cross Entropy,为了解决正负样本不平衡对损失函数造成的影响。

最原本的CE loss(cross entroy loss交叉熵损失函数)形式如下:

Focal Loss论文解读和调参教程

为了解决正负样本不平衡问题(负样本太多,正样本太少),一个nature的思路就是给正负样本添加权重alpha,用来减小负样本的占比影响,

Focal Loss论文解读和调参教程

 显然alpha越大,正样本的loss占比越大!即α设置的越大,负样本对loss的影响越小。这样就解决了正负样本数量不平衡对最后整个loss函数造成的影响。

下面解决难易样本数量不平衡:在训练时,易分样本数量远大于难分样本数量,易分样本指的是:target为正样本,且pred得分(检测框的score)高,即易分正样本;target为负样本,且pred得分低,即易分负样本

为此我们再引入一个权重gamma,用来减小易分样本的占比影响

Focal Loss论文解读和调参教程

 至此,只需要组合上面的α和γ,就得到了Focal Loss的最终形式:

Focal Loss论文解读和调参教程

 这种分类loss既能够缓解正负样本数量不均衡的问题,也能缓解难易样本数量不均衡问题,只引入了两个超参数。

值得一提的是,作者在原文中通过实验证明,在COCO数据集上,α取0.25,γ取2的组合精度最高。

RetinaNet

因为这篇文章里提出了一个比较著名的网络RetinaNet,因此顺便也介绍下。

RetinaNet是一个一阶段的网络,由一个主干网络和两个特定于任务(目标检测)的两个子网络(其实就是一个分类头+一个回归头)。

Focal Loss论文解读和调参教程

作者用这个很简单的retinanet当做一个一阶段算法的baseline,通过在上面用focal loss超越了二阶段的faster rcnn精度,同时又保留了一阶段的高效率。以此来证明一阶段和二阶段的算法精度差距确实就在于作者提出的类别不平衡猜想。

mmcv中focal loss实现源码和调参

这里首先提示一句,一般看到的二阶段算法的cls_loss都是最基础的CE loss,因为二阶段已经有成熟的RPN,因此生成的anchor或者说proposal的类别不均衡问题不严重,因此没必要用focal loss。

这里就以mmdet里的focal loss实现为例,源码位置在mmdet\models\losses\focal_loss.py

class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether to the prediction is
                used for sigmoid or softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as logits.
                Defaults to False.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = py_focal_loss_with_prob
            else:
                if torch.cuda.is_available() and pred.is_cuda:
                    calculate_loss_func = sigmoid_focal_loss
                else:
                    num_classes = pred.size(1)
                    target = F.one_hot(target, num_classes=num_classes + 1)
                    target = target[:, :num_classes]
                    calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)

        else:
            raise NotImplementedError
        return loss_cls

可以看到只需要在init这个loss的时候赋予gamma和alpha就可以,比如我改变我的htc算法config里的

loss_cls=dict(
    type='CrossEntropyLoss',
    use_sigmoid=False,
    loss_weight=1.0),

改成

loss_cls=dict(
    type='FocalLoss),

即可,用的alpha和gamma都是论文里默认的“最优决策”:α=0.25,γ=2.0

当然这两个超参数要根据你实际的数据集和任务场景调整。文章来源地址https://www.toymoban.com/news/detail-415560.html

到了这里,关于Focal Loss论文解读和调参教程的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • quality focal loss & distribute focal loss 解说(附代码)

    参见generalized focal loss paper 其中包含有 Quality Focal Loss 和 Distribution Focal Loss 。 dense detectors逐渐引领了目标检测领域的潮流。 目标框的表达方法,localization quality估计方法的改进引起了目标检测的逐渐进步。 其中,目标框表达(坐标或(l,r,t,b))目前被建模为一个简单的Dirac de

    2023年04月23日
    浏览(66)
  • quality focal loss & distribute focal loss 详解(paper, 代码)

    参见generalized focal loss paper 其中包含有 Quality Focal Loss 和 Distribution Focal Loss 。 dense detectors逐渐引领了目标检测领域的潮流。 目标框的表达方法,localization quality估计方法的改进引起了目标检测的逐渐进步。 其中,目标框表达(坐标或(l,r,t,b))目前被建模为一个简单的Dirac de

    2024年02月06日
    浏览(36)
  • MTN模型LOSS均衡相关论文解读

    目录 一、综述 二、依据任务不确定性加权多任务损失  三、依据不同任务的梯度大小来动态修正其loss权重GradNorm 四、根据LOSS变化动态均衡任务权重Dynamic Weight Average(DWA) 五、Reference  MTN模型主要用于两个方面,1.将多个模型合为一个显著降低车载芯片负载。2.将多个任务模

    2024年02月04日
    浏览(36)
  • 《Dense Distinct Query for End-to-End Object Detection》论文笔记(ing)

    作者这里认为传统个目标检测的anchor/anchorpoint其实跟detr中的query作用一样,可以看作query (1)dense query:传统目标检测生成一堆密集anchor,但是one to many需要NMS去除重复框,无法end to end。 (2)spare query 在one2one:egDETR,100个qeury,数量太少造成稀疏监督,收敛慢召回率低。 (

    2024年01月25日
    浏览(44)
  • Focal Loss介绍

      在目标检测算法中,我们会经常遇到Focal Loss这个东西,今天我们就来简单的分下下这个损失。   在深度学习训练的时候,在遇到目标类别不平衡时,如果直接计算损失函数,那么最终计算的结果可能会偏向于常见类别,低召回率,模型过拟合等问题。为了应对这个问

    2024年02月09日
    浏览(34)
  • pytorch如何使用Focal Loss

    Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行 decay 的一种损失函数。是对标准的 Cross Entropy Loss 的一种改进。 FL 对于简单样本(p比较大)回应较小的loss。 如论文中的图1, 在p=0.6时, 标准的 CE 然后又较大的 loss , 但是对于FL就有相对较小的loss回应

    2024年02月10日
    浏览(37)
  • 交叉熵、Focal Loss以及其Pytorch实现

    本文参考链接:https://towardsdatascience.com/focal-loss-a-better-alternative-for-cross-entropy-1d073d92d075 损失是通过梯度回传用来更新网络参数是之产生的预测结果和真实值之间相似。不同损失函数有着不同的约束作用,不同的数据对损失函数有着不同的影响。 交叉熵是常见的损失函数,常

    2024年02月11日
    浏览(58)
  • Focal Loss:类别不平衡的解决方案

    ❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️ 👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈 (封面图由ERNIE-ViLG AI 作画大模型生成) 在目标检测领域,常常使用交叉熵

    2024年02月06日
    浏览(65)
  • EIoU和Focal-EIoU Loss

    论文题目:《Focal and Efficient IOU Loss for Accurate Bounding Box Regression》 CIoU Loss虽然考虑了边界框回归的重叠面积、中心点距离、高宽比。但是其公式中的v反映的是高宽的差异,而不是高宽分别与其置信度的真实差异。因此,有时会阻碍模型有效的优化相似性。针对这一问题,本文

    2024年03月27日
    浏览(45)
  • Perceptual Loss(感知损失)&Perceptual Losses for Real-Time Style Transferand Super-Resolution论文解读

    由于传统的L1,L2 loss是针对于像素级的损失计算,且L2 loss与人眼感知的图像质量并不匹配,单一使用L1或L2 loss对于超分等任务来说恢复出来的图像往往细节表现都不好。 现在的研究中,L2 loss逐步被人眼感知loss所取代。人眼感知loss也被称为perceptual loss(感知损失),它与MSE(

    2023年04月20日
    浏览(47)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包