Swin Transformer详解

这篇具有很好参考价值的文章主要介绍了Swin Transformer详解。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

原创:余晓龙

“Swin Transformer: Hierarchical Vision Transformer using Shifted Window”是微软亚洲研究院(MSRA)发表在arXiv上的论文,文中提出了一种新型的Transformer架构,也就是Swin Transformer。本文旨在对Swin Transformer架构进行详细解析。

一、Swin Transformer网络架构

swin transformer,transformer,深度学习,计算机视觉,人工智能,机器学习

整体的网络架构采取层次化的设计,共包含4个stage,每个stage都会缩小输入特征图的分辨率,类似于CNN操作逐层增加感受野。对于一张输入图像224x224x3,首先会像VIT一样,把图片打成patch,这里Swin transformer中使用的patch size大小为 4x4,不同于VIT中使用的大小为16x16。经过Patch Partition,图像的大小会变成56 x 56 x 48, 其中48为 (4x4x3)3 为图片的rgb通道。打完patch之后会经过Linear Embedding,这里的主要目的是为了把向量的维度变成我们预先设定好的值,即可以满足transformer可以输入的值。在Swin-T网络中,这里C的大小为96,得到的网络输出值为56x56x96。之后经过拉直,序列长度变成3136 x 96。其代码如下:

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W


if __name__ == '__main__':
    x = torch.randn(8, 3, 224, 224)
    x, W, H = PatchEmbed()(x)
    print(x.size())  # torch.Size([8, 3136, 96])
    print(W)  # 56
    print(H)  # 56

Swin transformer引入了基于窗口的自注意力计算,每个窗口为 7x7=49个patch。如果想要有多尺度的特征信息,就需要构建一个层级式的transformer,类似卷积神经网络里的池化操作,Patch Merging用于缩小分辨率,调整通道数,完成层级式的设计。这里每次的降采样为2,在行和列方向每隔一个点选取元素,之后拼接在一起展开。

swin transformer,transformer,深度学习,计算机视觉,人工智能,机器学习

相当于在空间上的维度去换到了更多的通道数,维度变成4C,之后在C的维度上利用全连接层,将通道数的大小变成2C,经过上述操作之后网络输出的大小变为28 x 28 x 192。之后经过拉直,序列长度变成784 x 192。后面的stage3、stage4同理。最终的网络输出的大小变为7 x 7 x 768。之后经过拉直,序列长度变成49 x 768。代码如下:

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x


if __name__ == '__main__':                            
    x = torch.randn(8, 3, 224, 224)                   
    x, H, W = PatchEmbed()(x)                         
    # print(x.size())  # torch.Size([8, 3136, 96])    
    # print(W)  # 56                                  
    # print(H)  # 56                                  
                                                      
    x = PatchMerging(dim=96)(x, H, W)                 
    print(x.size())       # torch.Size([8, 784, 192]) 

基于窗口/移动窗口的自注意力

由于全局的自注意力计算会导致平方倍的复杂度,因此作者提出了基于窗口的自注意力机制。原来的图片会被平均分成一些没有重叠的窗口,以第一层的输入为例,尺寸大小为56 x 56 x 96。

swin transformer,transformer,深度学习,计算机视觉,人工智能,机器学习

在每一个小方格中会有7x7=49个patch,因此大的特征图可以分为 56 / 7 x 56 / 7 = 8 x 8 个窗口。

基于窗口的自注意力机制与基于全局的自注意力机制复杂度对比:

swin transformer,transformer,深度学习,计算机视觉,人工智能,机器学习

以标准的多头自注意力为例, 对于一个输入,自注意力首先会将它变成q, k, v三个向量,之后得到的q, k 相乘得到attention,在有了自注意力之后后和得到的v进行相乘,相当于做了一次加权,最后因为这是使用了多头自注意力机制,还会经过一个projection layer,这个投射层就会把向量的维度投射到我们想要的那个维度,如下图:

swin transformer,transformer,深度学习,计算机视觉,人工智能,机器学习

公式一 :
3 h w c 2 + ! ( h w ) 2 c + ( h w ) 2 c + h w c 2 3hwc^{2} + ! (hw)^{2}c + (hw)^{2}c + hwc^{2} 3hwc2+!(hw)2c+(hw)2c+hwc2

公式二:基于窗口的自注意力复杂度 一个窗口大小 M x M 代入公式一得

4 M 2 c 2 + 2 M 4 c 4M^{2}c^{2} + 2M^{4}c 4M2c2+2M文章来源地址https://www.toymoban.com/news/detail-736909.html

到了这里,关于Swin Transformer详解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Swin-Transformer 详解

    由于Transformer的大火,相对应的也出来了许多文章,但是这些文章的速度和精度相较于CNN还是差点意思,2021年微软研究院发表在ICCV上的一篇文章Swin Transformer是Transformer模型在视觉领域的又一次碰撞,Swin Transformer可能是CNN的完美替代方案。 论文名称:Swin Transformer: Hierarchical

    2024年02月04日
    浏览(29)
  • Swin-Transformer(原理 + 代码)详解

    图解Swin Transformer Swin-Transformer网络结构详解 【机器学习】详解 Swin Transformer (SwinT) 论文下载 官方源码下载 学习的话,请下载 Image Classification 的代码,配置相对简单,其他的配置会很麻烦。如下图所示: Install : pytorch安装:感觉pytorch 1.4版本都没问题的。 2、pip install timm==

    2023年04月08日
    浏览(34)
  • Swin Transformer之相对位置编码详解

    目录 一、概要 二、具体解析 1. 相对位置索引计算第一步  2. 相对位置索引计算第二步 3. 相对位置索引计算第三步      在 Swin Transformer 采用了 相对位置编码 的概念。       那么相对位置编码的作用是什么呢?           解释: 在解释相对位置编码之前,我们需要先了解

    2023年04月16日
    浏览(28)
  • Swin-Transformer网络结构详解

    Swin Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得 ICCV 2021 best paper 的荣誉称号。Swin Transformer网络是Transformer模型在视觉领域的又一次碰撞。该论文一经发表就已在多项视觉任务中霸榜。该论文是在2021年3月发表的,现在是2021年11月了,根据官方提供的信息

    2024年02月04日
    浏览(30)
  • AI大模型之Swin Transformer 最强CV图解(深度好文)

    目录 SwinTransformer之CV模型详解 第一代CV大模型:Vision Transformer 第二代CV大模型:Swin Transformer 两代模型PK(VIT和Swin Transformer) Swin Transformer是什么CV模型? Swin Transformer应用场景是什么? Swin Transformer到底解决了什么问题? Swin Transformer网络架构 Patch Embbeding介绍 window_partition介绍

    2024年04月28日
    浏览(21)
  • 论文学习笔记:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

    论文阅读:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 今天学习的论文是 ICCV 2021 的 best paper,Swin Transformer,可以说是 transformer 在 CV 领域的一篇里程碑式的工作。文章的标题是一种基于移动窗口的层级 vision transformer。文章的作者都来自微软亚研院。 Abstract 文章的

    2024年02月08日
    浏览(29)
  • Swin-transformer论文阅读笔记(Swin Transformer: Hierarchical Vision Transformer using Shifted Windows)

    论文标题:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 论文作者:Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo 论文来源:ICCV 2021,Paper 代码来源:Code 目录 1. 背景介绍 2. 研究现状 CNN及其变体 基于自注意的骨干架构 自注意/Transformer来补充CN

    2024年02月07日
    浏览(37)
  • VIT与swin transformer

    VIT也就是vision transformer的缩写。是第一种将transformer运用到计算机视觉的网络架构。其将注意力机制也第一次运用到了图片识别上面。其结构图如下(采用的是paddle公开视频的截图) 看起来比较复杂,但实际上总体流程还是比较简单的。只需要看最右边的总的结构图,它的输

    2024年02月05日
    浏览(30)
  • YOLOv5+Swin Transformer

    参考:(7条消息) 改进YOLOv5系列:3.YOLOv5结合Swin Transformer结构,ICCV 2021最佳论文 使用 Shifted Windows 的分层视觉转换器_芒果汁没有芒果的博客-CSDN博客 本科生工科生cv改代码 本来做的7,但是7报错一直解决不了,我就试试5 1、先是第一个报错 解决:在yolo.py里 2、 解决:common里删

    2024年02月12日
    浏览(25)
  • 关于Swin Transformer的架构记录

    Swin Transformer 可以说是批着Transformer外表的卷积神经网络。 具体的架构如下图所示: 首先我们得到一张224*224*3的图片。 通过分成4*4的patch,变成了56*56*48。 线性变换后又变成了56*56*96。 然后利用了Swin Transformer中一个比较特别的结构 Patch Merging 变成28*28*192。 同理,变成14*14*3

    2024年02月20日
    浏览(24)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包