CutMix原理与代码解读

这篇具有很好参考价值的文章主要介绍了CutMix原理与代码解读。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

paper:CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features

前言

之前的数据增强方法存在的问题:

mixup:混合后的图像在局部是模糊和不自然的,因此会混淆模型,尤其是在定位方面。

cutout:被cutout的部分通常用0或者随机噪声填充,这就导致在训练过程中这部分的信息被浪费掉了。

cutmix在cutout的基础上进行改进,cutout的部分用另一张图像上cutout的部分进行填充,这样即保留了cutout的优点:让模型从目标的部分视图去学习目标的特征,让模型更关注那些less discriminative的部分。同时比cutout更高效,cutout的部分用另一张图像的部分进行填充,让模型同时学习两个目标的特征。

从下图可以看出,虽然Mixup和Cutout都提升了模型的分类精度,但在若监督定位和目标检测性能上都有不同程度的下降,而CutMix则在各个任务上都获得了显著的性能提升。

CutMix原理与代码解读

CutMix

cutmix的具体过程如下

CutMix原理与代码解读

其中\(M\in\left \{ 0,1 \right \}^{W\times H}\)是一个binary mask表明从两张图中裁剪的patch的位置,和mixup一样,\(\lambda\)也是通过\(\beta(\alpha, \alpha)\)分布得到的,在文章中作者设置\(\alpha=1\),因此\(\lambda\)是从均匀分布\((0,1)\)中采样的。

为了得到mask,首先要确定cutmix的bounding box的坐标\(B=(r_{x},r_{y},r_{w},r_{h})\),其值通过下式得到

CutMix原理与代码解读

即 \(\lambda\) 确定了patch与原图的面积比,即A图cutout的面积越大,标签融合时A图的比例越小。

代码实现

下面是torchvision的官方实现

class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError("Please provide a valid positive value for the num_classes.")
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        _, H, W = F.get_dimensions(batch)

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s

实验结果

从下图可以看出,CutMix在ImageNet上的精度超过了Cutout和Mixup等数据增强方法

CutMix原理与代码解读

在若监督目标定位方面,CutMix也超过了Mixup和Cutout

CutMix原理与代码解读

当作为预训练模型迁移到其它下游任务比如目标检测和图像描述时,CutMix也取得了最好的效果

CutMix原理与代码解读文章来源地址https://www.toymoban.com/news/detail-440574.html

到了这里,关于CutMix原理与代码解读的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • stm32-DHT11原理及代码解读

    目录 一、基础知识 1.功能:温湿度检测 2.应用范围 3.硬件电路连接 二、底层代码原理分析 1.基础知识 1.单总线说明 2.单总线传送数据位定义 3.数据格式 4.校验位数据定义 2.代码分析 1.数据时序图 2.数据传输步骤         测量范围湿度:湿度:5-95%RH        精度:(±

    2023年04月08日
    浏览(34)
  • 一文读懂Stable Diffusion 论文原理+代码超详细解读

    Stable diffusion是一个基于Latent Diffusion Models(LDMs)实现的的文图生成(text-to-image)模型。 2022年8月,游戏设计师Jason Allen凭借AI绘画作品《太空歌剧院(Théâtre D’opéra Spatial)》获得美国科罗拉多州博览会“数字艺术/数码摄影“竞赛单元一等奖,“AI绘画”引发全球热议。得力

    2024年01月19日
    浏览(57)
  • MUSIC算法相关原理知识(物理解读+数学推导+Matlab代码实现)

    部分来自于网络教程,如有侵权请联系本人删除  教程链接:MUSIC算法的直观解释:1,MUSIC算法的背景和基础知识_哔哩哔哩_bilibili  MUSIC算法的直观解释:2,我对于MUSIC算法的理解_哔哩哔哩_bilibili https://blog.csdn.net/zhangziju/article/details/100730081  一、MUSIC算法作用 MUSIC (Multiple

    2024年02月02日
    浏览(40)
  • stm32-CS100A 超声波测距芯片原理及代码解读

            CS100A 是苏州顺憬志联新材料科技有限公司(www.100sensor.com)推出的一款工 业级超声波测距芯片,CS100A 内部集成超声波发射电路,超声波接收电路,数字处理电 路等,单芯片即可完成超声波测距,测距结果通过脉宽的方式进行输出,通信接口兼容 现有超声波模块

    2024年02月04日
    浏览(155)
  • stm32-HY-SRF05 超声波模块-原理及代码解读

    目录 一、基础知识 1.功能:超声波测距 2.硬件介绍及电路连接 二、底层代码原理分析 1基本工作原理 2代码分析 1时序图 步骤1 步骤2       HY-SRF05 超声波测距模块可提供2cm-450cm 的非接触式距离感测功能,测距精度可达高到3mm            VCC 供5V 电源, GND 为地线, TRIG 触

    2024年02月02日
    浏览(54)
  • quality focal loss & distribute focal loss 详解(paper, 代码)

    参见generalized focal loss paper 其中包含有 Quality Focal Loss 和 Distribution Focal Loss 。 dense detectors逐渐引领了目标检测领域的潮流。 目标框的表达方法,localization quality估计方法的改进引起了目标检测的逐渐进步。 其中,目标框表达(坐标或(l,r,t,b))目前被建模为一个简单的Dirac de

    2024年02月06日
    浏览(40)
  • 【2023年五一数学建模竞赛B题】快递需求分析问题--完整paper和代码

    赛题分析:这道题出的比较好,考察面较多,难度循环渐进,相对C题是比较有层次的一道题 请从收货量、发货量、快递数量增长/减少趋势、相关性等多角度考虑,建立数学模型,对各站点城市的重要程度进行综合排序,并给出重要程度排名前5的站点城市名称。 第一问比较

    2024年02月05日
    浏览(52)
  • NILM非侵入式负荷识别(papers with code、data)带代码的论文整理——(论文及实现代码篇) 全网最全

            研究生三年快毕业了,毕业前整理一下该领域的研究工作。正所谓,我栽树,后人乘凉。研究NILM的时候,个人觉得最快的方法是直接复现别人的论文,或者甚至用别人论文的代码直接跑出来体会整个流程(数据集导入-数据预处理-运行模型-输出结果)。研究生三

    2024年02月05日
    浏览(47)
  • 【代码复现系列】paper:CycleGAN and pix2pix in PyTorch

    或许有冗余步骤、之后再优化。 1.桌面右键-git bash-输入命令如下【git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix】 2.打开anaconda的prompt,cd到pytorch-CycleGAN-and-pix2pix路径 3.在prompt里输入【conda env create -f environment.yml】配置虚拟环境及相应的包 4.在prompt里输入【conda activate py

    2024年02月01日
    浏览(34)
  • NILM非侵入式负荷识别(papers with code、data)带代码的论文整理——(公开数据集、工具、和性能指标篇) 全网最全

    Q1:文章里面没有附上代码链接的文章是不是没有源码? Q2:xxx数据集找不到,xxx代码网址打不开了,博主能不能发我一份? 这篇文章主要介绍用于非侵入式负荷识别领域目前的公开数据集、工具和其它等,如果需要看论文及具体代码实现,看我上一篇的文章。 其外, 不是

    2023年04月20日
    浏览(44)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包