GCNet: Global Context Network(ICCV 2019)原理与代码解析

这篇具有很好参考价值的文章主要介绍了GCNet: Global Context Network(ICCV 2019)原理与代码解析。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

paper:GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond

official implementaion:https://github.com/xvjiarui/GCNet

Third party implementation:https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/context_block.py

存在的问题

通过捕获long-range dependency提取全局信息,对各种视觉任务都是很有帮助的。Non-local Network(介绍见https://blog.csdn.net/ooooocj/article/details/124573078)通过自注意力机制来解决这个问题。对于每个查询位置(query position),non-local network首先计算该位置和所有位置之间一个两两成对的关系,得到一个attention map。然后对attention map所有位置的权重加权求和得到汇总特征,每一个查询位置都得到一个汇总特征,将汇总特征与原始特征相加得到最终输出。

对于某个query position,non-local network计算的另一个位置与该位置的关系即一个权重值表示这个位置对query位置的重要程度。本文可视化attention map发现,不同的query位置其对应的attention map几乎一样,如下图所示

gcnet代码,Attention,深度学习,人工智能,注意力机制,Powered by 金山文档

non-local block可以表示为下式

gcnet代码,Attention,深度学习,人工智能,注意力机制,Powered by 金山文档

其中 \(i\) 是query position的索引,\(j\) 遍历所有位置,\(f(\mathbf{x}_{i},\mathbf{x}_{j})\) 表示位置 \(i\) 和 \(j\) 之间的关系,\(\mathcal{C}(\mathbf{x})\) 是归一化因子,\(W_{z}\) 和 \(W_{v}\) 是线性变换矩阵例如 \(1\times1\) 卷积。non-local block有多种不同的实例化方法,例如Gaussian、Embedded Gaussian、Dot product、Concat,下图(a)是Embedded Gaussian的结构。

gcnet代码,Attention,深度学习,人工智能,注意力机制,Powered by 金山文档

由于需要计算每个query位置的attention map,因此non-local block的时间和空间复杂度都是所有位置的平方关系。

下图是作者从COCO数据集中随机挑选的6张图片,并可视化出3个不同的query position即图中的红点与对应的query-specific attention map,可以看出对于不同的查询位置,它们的attention map几乎是相同的。

gcnet代码,Attention,深度学习,人工智能,注意力机制,Powered by 金山文档

为了进一步验证这一观察结果,作者又分析了不同的查询位置与全局上下文之间的距离。结果如下表所示。其中计算了三种向量之间的余弦距离cosine distance,分别为non-local block的输入、输出、以及查询的注意力图,对应表中的input、output、att。

gcnet代码,Attention,深度学习,人工智能,注意力机制,Powered by 金山文档

从表中可以看出,input列的余弦距离比较大表明non-local的输入特征可以在不同的位置进行区分,但output列的余弦距离非常小,表明non-local block建模的全局上下文特征对于不同的query position几乎是相同的。attention map上的距离非常小,也验证了可视化的观察结果。

尽管non-local block是打算针对每个位置计算全局上下文的,但训练后的全局上下文实际上是独立于查询位置的。因为没有必要为每个查询位置单独计算query-specific全局上下文。

本文的创新点

本文通过观察发现non-local block针对每个query position计算的attention map最终结果是独立于查询位置的,那么就没有必要针对每个查询位置计算了,因此提出计算一个通用的attention map并应用于输入feature map上的所有位置,大大减少了计算量的同时又没有导致性能的降低。此外,结合SE block,设计了一个新的Global Context (GC) block,既轻量又可以有效地建模全局上下文。GC Block结合了Non-local block和SE block的优点,基于GC Block设计的GCNet在多个任务上均超过了NLNet和SENet。

方法介绍

作者舍去了式(1)中的 \(W_{z}\) 即图(3)(a)中的query分支,得到下式

gcnet代码,Attention,深度学习,人工智能,注意力机制,Powered by 金山文档

这里采用的是最常用的Embedded Gaussiian的实例化方式。简化后的non-local block如图(3)(b)所示。

为了进一步减少计算量,作者应用分配率将 \(W_{v}\) 移到attention pool的外面,如下

gcnet代码,Attention,深度学习,人工智能,注意力机制,Powered by 金山文档

这里简化后的non-local block如图4(b)所示。1x1卷积 \(W_{v}\) 的FLOPs从 \(\mathcal{O}(HWC^{2})\) 减小到 \(\mathcal{O}(C^{2})\)。

gcnet代码,Attention,深度学习,人工智能,注意力机制,Powered by 金山文档

到目前为止简化的NL block中参数量最大的部分在transform module即图4(b)中的Transform部分,这里是一个1x1卷积但参数量为 \(C\cdot C\),当把Nl block应用到较深的层例如resnet中的 \(res_{5}\) 时,CxC=2028x2048,占据了整个block大部分的计算量。为了进一步减少计算量,作者借鉴了SE block的思想如图4(c),将图4(b)中的transform module换成了图4(d)中的bottleneck transform module,其中 \(r\) 是reduction ratio,这样参数量就从C·C变成了2·C·C/r,默认情况下r=16,因此参数量就减少为1/8。

实验结果

下表baseline是backbone为ResNet-50的Mask R-CNN在COCO数据集上的目标检测和实例分割的结果。将1个non-local block(NL)、1个simplified non-local block(SNL)、1个global context block(GC)插入到c4的最后一个residual block前,可以看出GC block获得的相似的性能但参数量更小。将GC block添加到所有的residual block中在参数量相似的情况下得到了更高的性能。

gcnet代码,Attention,深度学习,人工智能,注意力机制,Powered by 金山文档

代码解析

这里的代码是mmcv中的实现。文章来源地址https://www.toymoban.com/news/detail-650469.html

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
from torch import nn

from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS


def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
    if isinstance(m, nn.Sequential):
        constant_init(m[-1], val=0)
    else:
        constant_init(m, val=0)


@PLUGIN_LAYERS.register_module()
class ContextBlock(nn.Module):
    """ContextBlock module in GCNet.

    See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
    (https://arxiv.org/abs/1904.11492) for details.

    Args:
        in_channels (int): Channels of the input feature map.
        ratio (float): Ratio of channels of transform bottleneck
        pooling_type (str): Pooling method for context modeling.
            Options are 'att' and 'avg', stand for attention pooling and
            average pooling respectively. Default: 'att'.
        fusion_types (Sequence[str]): Fusion method for feature fusion,
            Options are 'channels_add', 'channel_mul', stand for channelwise
            addition and multiplication respectively. Default: ('channel_add',)
    """

    _abbr_ = 'context_block'

    def __init__(self,
                 in_channels: int,
                 ratio: float,
                 pooling_type: str = 'att',
                 fusion_types: tuple = ('channel_add', )):
        super().__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.in_channels = in_channels
        self.ratio = ratio
        self.planes = int(in_channels * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out

到了这里,关于GCNet: Global Context Network(ICCV 2019)原理与代码解析的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 注意力机制(SE, ECA, CBAM, SKNet, scSE, Non-Local, GCNet, ASFF) Pytorch代码

    2023.3.2新增SKNet代码 2023.3.10 新增 scSE代码 2023.3.11 新增 Non-Local Net 非局部神经网络 2023.3.13新增GCNet 2023.6.7新增ASFF SE注意力机制(Squeeze-and-Excitation Networks) :是一种 通道类型 的注意力机制,就是在通道维度上增加注意力机制,主要内容是是 squeeze 和 excitation . 就是使用另外一个

    2024年02月08日
    浏览(44)
  • 【论文笔记】KDD2019 | KGAT: Knowledge Graph Attention Network for Recommendation

    为了更好的推荐,不仅要对user-item交互进行建模,还要将关系信息考虑进来 传统方法因子分解机将每个交互都当作一个独立的实例,但是忽略了item之间的关系(eg:一部电影的导演也是另一部电影的演员) 高阶关系:用一个/多个链接属性连接两个item KG+user-item graph+high orde

    2024年02月16日
    浏览(38)
  • 【深入解析spring cloud gateway】04 Global Filters

    上一节学习了GatewayFilter。 回忆一下一个关键点: GateWayFilterFactory的本质就是:针对配置进行解析,为指定的路由,添加Filter,以便对请求报文进行处理。 GlobalFilter又是啥?先看一下接口定义 再看一下GatewayFilter 可以看到GatewayFilter和GlobalFilter方法签名是一模一样的,那为啥又

    2024年02月09日
    浏览(36)
  • python global函数用法及常用的 global函数代码

      Python中的 global函数是用于在程序中定义变量的函数,在我们实际的开发中,我们可能会用到 global函数来定义变量,但是我们在这里就不具体介绍它的用法了。 global函数定义变量的方法: global函数使用参数a来指定变量在程序中的地址。 参数b表示该变量在程序中的地址。

    2024年02月05日
    浏览(41)
  • PaDiM 原理与代码解析

    paper: PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization code1: https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master  code2: https://github.com/openvinotoolkit/anomalib  异常检测是一个区分正常类别与异常类别的二分类问题,但实际应用中我们经常缺乏异常样本

    2024年01月17日
    浏览(24)
  • 动态规划算法:原理、示例代码和解析

    动态规划算法是一种常用的优化问题解决方法,它可以应用于许多计算机科学和其他领域的问题。动态规划算法的基本思想是将一个大问题分解成多个子问题,并将每个子问题的解存储在一个表中。通过计算表中的值,可以得到最终问题的解。在本文中,我们将介绍动态规划

    2024年02月02日
    浏览(39)
  • 扩散模型原理+DDPM案例代码解析

    扩散模型和一般的机器学习的神经网络不太一样!一般的神经网络旨在构造一个网络模型来拟合输入数据与希望得到的输出结果,可以把一般的神经网络当作一个黑盒,这个黑盒通过训练使其输入数据后就可以得到我们想要的结果。而扩散模型包含了大量的统计学和数学相关

    2024年02月16日
    浏览(42)
  • DBNet++(TPAMI) 原理与代码解析

    paper: Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion code1: https://github.com/MhLiao/DB code2: https://github.com/open-mmlab/mmocr 本文是对DBNet的改进,关于DBNet的介绍具体可见场景文本检测算法 可微分二值化DBNet原理与代码解析,本文新提出了一种自适应尺度融合模

    2024年02月16日
    浏览(30)
  • PANet(CVPR 2018)原理与代码解析

    paper:Path Aggregation Network for Instance Segmentation official implementation:GitHub - ShuLiu1993/PANet: PANet for Instance Segmentation and Object Detection  third party implementation:mmdetection/pafpn.py at master · open-mmlab/mmdetection · GitHub  信息在神经网络中的传播方式是非常重要的。本文提出的路径聚合网络(

    2024年02月13日
    浏览(68)
  • [BEV] 学习笔记之BEVDet(原理+代码解析)

    前言 基于LSS的成功,鉴智机器人提出了BEVDet,目前来到了2.0版本,在nuscences排行榜中以mAP=0.586暂列第一名。本文将对BEVDet的原理进行简要说明,然后结合代码对BEVDet进深度解析。 repo: https://github.com/HuangJunJie2017/BEVDet paper:https://arxiv.org/abs/2211.17111 欢迎进入BEV感知交流群,一起

    2024年02月05日
    浏览(45)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包