【Focal Net】NeuralPS2022 论文+代码解读 Focal Modulation Networks

这篇具有很好参考价值的文章主要介绍了【Focal Net】NeuralPS2022 论文+代码解读 Focal Modulation Networks。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Focal Modulation Networks

1. 资源链接

论文链接-arxiv
官方代码-github
官方博客-microsoft

2. 摘要

该论文提出了一个Focal Modulaiton network,将自注意力机制替换成focal modulation(聚焦调制)。这种机制包括3个组件:1)通过depth-wise Conv提取分级的上下文信息,同时编码短期和长期依赖。2)门控聚合,基于每个token的内容选择性的聚集视觉上下文。3)通过点乘或者仿射变换将汇聚的上下文信息注入query。
Focal Net 主要是在block中加入了Mulit-level 的特征融合机制,类似于目标检测中很常见的 FPN结构,同时学习粗粒度的空间信息和细粒度的特征信息,提高网络的性能。该网络做为新型的backbone,在分类,分割,目标检测,实例分割等任务上都取得了非常好的效果,尤其是基于DETR 框架的检测算法在COCO上取得了新的SOTA结果。
【Focal Net】NeuralPS2022 论文+代码解读 Focal Modulation Networks

3. 结果

从下图可以看到,基于FocalNet的检测算法模型相对较小,训练数据也比较少,性能却有提高。
【Focal Net】NeuralPS2022 论文+代码解读 Focal Modulation Networks

4. 特征可视化

从门控信号可视化可以看到,正如文章宣称的,不同level的特征可以注意到图像中不同的区域,包括图像局部特征区域和全局空间信息。
【Focal Net】NeuralPS2022 论文+代码解读 Focal Modulation Networks

6. 和自注意力机制对比

6.1 定性分析

相比于自监督,FocalModulation 的输出关注了多尺度的上下文,算子更轻量化。
【Focal Net】NeuralPS2022 论文+代码解读 Focal Modulation Networks

6.2 结构差异

自注意力中,key和qury是密集的矩阵相乘,Attention也是和value的密集矩阵乘积。而FocalNet中分别采用Depth-Wise Conv和Point-Wise Conv,计算更轻量化。
【Focal Net】NeuralPS2022 论文+代码解读 Focal Modulation Networks

7核心代码

Focal Modulation代码

class FocalModulation(nn.Module):
    def __init__(self,
                 dim,
                 focal_window,
                 focal_level,
                 focal_factor=2,
                 bias=True,
                 proj_drop=0.,
                 use_postln=False):
        super().__init__()
        self.dim = dim
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.focal_factor = focal_factor
        self.use_postln = use_postln

        self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=bias) 
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)  #1x1 卷积

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)  # output_dim=input_dim
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()
        self.kernel_sizes = []
        for k in range(self.focal_level):  # Hierarchical Context
            kernel_size = self.focal_factor * k + self.focal_window  # 多尺度kenel_size
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(dim,
                              dim,
                              kernel_size=kernel_size,
                              stride=1,
                              groups=dim,  # groups==input_dim  depth-wise conv
                              padding=kernel_size // 2,
                              bias=False),
                    nn.GELU(),
                ))
            self.kernel_sizes.append(kernel_size)
        if self.use_postln:
            self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, H, W, C)
        """
        C = x.shape[-1]

        # pre linear projection
        x = self.f(x).permute(0, 3, 1, 2).contiguous()
        q, ctx, self.gates = torch.split(x, (C, C, self.focal_level + 1), 1)

        # context aggreation
        ctx_all = 0
        for l in range(self.focal_level):
            ctx = self.focal_layers[l](ctx)
            ctx_all = ctx_all + ctx * self.gates[:, l:l + 1]  #Gated Aggregation
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) # AvgPool 
        ctx_all = ctx_all + ctx_global * self.gates[:, self.focal_level:]  # 在局部 Context 上加入 全局Context

        # focal modulation
        self.modulator = self.h(ctx_all)
        x_out = q * self.modulator
        x_out = x_out.permute(0, 2, 3, 1).contiguous()
        if self.use_postln:
            x_out = self.ln(x_out)
        # post linear porjection
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out

Self-Attention 代码文章来源地址https://www.toymoban.com/news/detail-465250.html

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

到了这里,关于【Focal Net】NeuralPS2022 论文+代码解读 Focal Modulation Networks的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • centerpoint论文和代码解读

      目录 一、序论 二、论文结构 三、代码 论文地址: https://arxiv.org/pdf/2006.11275.pdf  代码地址:tianweiy/CenterPoint (github.com) centorpoint是一种anchor-free的方法,直接预测物体的中心点,然后直接回归其whl,省去了anchor与GT匹配过程(传统的anchor-base方法需要计算GT和anchor的iou进行分配

    2024年02月12日
    浏览(27)
  • [论文阅读笔记18] DiffusionDet论文笔记与代码解读

    扩散模型近期在图像生成领域很火, 没想到很快就被用在了检测上. 打算对这篇论文做一个笔记. 论文地址: 论文 代码: 代码 首先介绍什么是扩散模型. 我们考虑生成任务, 即encoder-decoder形式的模型, encoder提取输入的抽象信息, 并尝试在decoder中恢复出来. 扩散模型就是这一类中的

    2023年04月08日
    浏览(43)
  • 【论文解读】用于代码处理的语言模型综述

    目录 1.简要介绍 2.代码处理的语言模型的评估 3.通用语言模型 4.用于代码处理的特定语言模型 5.语言模型的代码特性 6.软件开发中的LLM 7.结论与挑战 ​​​​​​​ 1.简要介绍 在这项工作中,论文系统地回顾了在代码处理方面的最新进展,包括50个+模型,30个+评估任务和5

    2024年01月18日
    浏览(31)
  • ResNet论文解读及代码实现(pytorch)

    又重新看了一遍何凯明大神的残差网络,之前懵懵懂懂的知识豁然开朗了起来。然后,虽然现在CSDN和知乎的风气不是太好,都是一些复制粘贴别人的作品来给自己的博客提高阅读量的人,但是也可以从其中汲取到很多有用的知识,我们要取其精华,弃其糟粕。 我只是大概的

    2024年02月04日
    浏览(37)
  • quality focal loss & distribute focal loss 解说(附代码)

    参见generalized focal loss paper 其中包含有 Quality Focal Loss 和 Distribution Focal Loss 。 dense detectors逐渐引领了目标检测领域的潮流。 目标框的表达方法,localization quality估计方法的改进引起了目标检测的逐渐进步。 其中,目标框表达(坐标或(l,r,t,b))目前被建模为一个简单的Dirac de

    2023年04月23日
    浏览(52)
  • 3D目标检测--PointPillars论文和OpenPCDet代码解读

    解决传统基于栅格化的3D目标检测方法在面对高密度点云数据时的性能瓶颈; 栅格化方法需要将点云数据映射到规则的网格中,但是对于高密度点云,栅格化操作会导致严重的信息损失和运算效率低下; 因此,该论文提出了一种新的基于点云的3D目标检测方法——PointPillars,

    2023年04月22日
    浏览(68)
  • 张正友标定论文的解读和C++代码编写

    张正友标定相机内参是非常经典的标定算法,现在代码已经被集成到MATLAB和opencv里面。不过因为算法涉及到基础的相机坐标系、图像坐标系、公式推导,以及优化算法,故根据张正友论文进行分模块代码编写。 https://github.com/Shelfcol/Zhangzhengyou_calib_cam_intrinsic 此C++代码是根据张

    2024年02月05日
    浏览(55)
  • quality focal loss & distribute focal loss 详解(paper, 代码)

    参见generalized focal loss paper 其中包含有 Quality Focal Loss 和 Distribution Focal Loss 。 dense detectors逐渐引领了目标检测领域的潮流。 目标框的表达方法,localization quality估计方法的改进引起了目标检测的逐渐进步。 其中,目标框表达(坐标或(l,r,t,b))目前被建模为一个简单的Dirac de

    2024年02月06日
    浏览(28)
  • 机器视觉 多模态学习11篇经典论文代码以及解读

    此处整理了深度学习-机器视觉,最新的发展方向-多模态学习,中的11篇经典论文,整理了相关解读博客和对应的Github代码,看完此系列论文和博客,相信你能快速切入这个方向。每篇论文、博客或代码都有相关标签,一目了然,整理到这里了 webhub123 机器视觉 多模态学习

    2024年02月13日
    浏览(28)
  • 【语义分割】ST_Unet论文 逐步代码解读

    主要工程文件为这5个 分别作用为: 构造相应的deform 卷积 DCNN的残差网络 编写相应的配置文件,可以改变相应参数 模型的主函数和主框架 模型的连接部分 代码框架由3部分组成,encode,decode和decode中将图像还原成语义分割预测图 Transformer(config, img_size) 组成编码部分,包含主

    2024年02月07日
    浏览(32)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包