原创:余晓龙
“Swin Transformer: Hierarchical Vision Transformer using Shifted Window”是微软亚洲研究院(MSRA)发表在arXiv上的论文,文中提出了一种新型的Transformer架构,也就是Swin Transformer。本文旨在对Swin Transformer架构进行详细解析。
一、Swin Transformer网络架构
整体的网络架构采取层次化的设计,共包含4个stage,每个stage都会缩小输入特征图的分辨率,类似于CNN操作逐层增加感受野。对于一张输入图像224x224x3,首先会像VIT一样,把图片打成patch,这里Swin transformer中使用的patch size大小为 4x4,不同于VIT中使用的大小为16x16。经过Patch Partition,图像的大小会变成56 x 56 x 48, 其中48为 (4x4x3)3 为图片的rgb通道。打完patch之后会经过Linear Embedding,这里的主要目的是为了把向量的维度变成我们预先设定好的值,即可以满足transformer可以输入的值。在Swin-T网络中,这里C的大小为96,得到的网络输出值为56x56x96。之后经过拉直,序列长度变成3136 x 96。其代码如下:
class PatchEmbed(nn.Module):
"""
2D Image to Patch Embedding
"""
def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_chans = in_c
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
_, _, H, W = x.shape
# padding
# 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
if pad_input:
# to pad the last 3 dimensions,
# (W_left, W_right, H_top,H_bottom, C_front, C_back)
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
0, self.patch_size[0] - H % self.patch_size[0],
0, 0))
# 下采样patch_size倍
x = self.proj(x)
_, _, H, W = x.shape
# flatten: [B, C, H, W] -> [B, C, HW]
# transpose: [B, C, HW] -> [B, HW, C]
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
if __name__ == '__main__':
x = torch.randn(8, 3, 224, 224)
x, W, H = PatchEmbed()(x)
print(x.size()) # torch.Size([8, 3136, 96])
print(W) # 56
print(H) # 56
Swin transformer引入了基于窗口的自注意力计算,每个窗口为 7x7=49个patch。如果想要有多尺度的特征信息,就需要构建一个层级式的transformer,类似卷积神经网络里的池化操作,Patch Merging用于缩小分辨率,调整通道数,完成层级式的设计。这里每次的降采样为2,在行和列方向每隔一个点选取元素,之后拼接在一起展开。
相当于在空间上的维度去换到了更多的通道数,维度变成4C,之后在C的维度上利用全连接层,将通道数的大小变成2C,经过上述操作之后网络输出的大小变为28 x 28 x 192。之后经过拉直,序列长度变成784 x 192。后面的stage3、stage4同理。最终的网络输出的大小变为7 x 7 x 768。之后经过拉直,序列长度变成49 x 768。代码如下:
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""
x: B, H*W, C
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
# 如果输入feature map的H,W不是2的整数倍,需要进行padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
# to pad the last 3 dimensions, starting from the last dimension and moving forward.
# (C_front, C_back, W_left, W_right, H_top, H_bottom)
# 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C]
x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
x = self.norm(x)
x = self.reduction(x) # [B, H/2*W/2, 2*C]
return x
if __name__ == '__main__':
x = torch.randn(8, 3, 224, 224)
x, H, W = PatchEmbed()(x)
# print(x.size()) # torch.Size([8, 3136, 96])
# print(W) # 56
# print(H) # 56
x = PatchMerging(dim=96)(x, H, W)
print(x.size()) # torch.Size([8, 784, 192])
基于窗口/移动窗口的自注意力
由于全局的自注意力计算会导致平方倍的复杂度,因此作者提出了基于窗口的自注意力机制。原来的图片会被平均分成一些没有重叠的窗口,以第一层的输入为例,尺寸大小为56 x 56 x 96。
在每一个小方格中会有7x7=49个patch,因此大的特征图可以分为 56 / 7 x 56 / 7 = 8 x 8 个窗口。
基于窗口的自注意力机制与基于全局的自注意力机制复杂度对比:
以标准的多头自注意力为例, 对于一个输入,自注意力首先会将它变成q, k, v三个向量,之后得到的q, k 相乘得到attention,在有了自注意力之后后和得到的v进行相乘,相当于做了一次加权,最后因为这是使用了多头自注意力机制,还会经过一个projection layer,这个投射层就会把向量的维度投射到我们想要的那个维度,如下图:
公式一 :
3 h w c 2 + ! ( h w ) 2 c + ( h w ) 2 c + h w c 2 3hwc^{2} + ! (hw)^{2}c + (hw)^{2}c + hwc^{2} 3hwc2+!(hw)2c+(hw)2c+hwc2
公式二:基于窗口的自注意力复杂度 一个窗口大小 M x M 代入公式一得文章来源:https://www.toymoban.com/news/detail-736909.html
4 M 2 c 2 + 2 M 4 c 4M^{2}c^{2} + 2M^{4}c 4M2c2+2M文章来源地址https://www.toymoban.com/news/detail-736909.html
到了这里,关于Swin Transformer详解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!