Code
import torch
from mmdet.models.backbones import ResNet
from fightingcv_attention.attention.CoordAttention import CoordAtt
from fightingcv_attention.attention.SEAttention import SEAttention
from mmdet.models.builder import BACKBONES

# Base class: ResNet with attention modules on its output feature maps.
class ResNetWithAttention(ResNet):
    def __init__(self, **kwargs):
        super(ResNetWithAttention, self).__init__(**kwargs)
        # ResNet outputs four feature maps (assuming the default
        # out_indices=(0, 1, 2, 3)); for now the attention modules are
        # attached to the last three of them.
        if self.depth in (18, 34):
            self.dims = (64, 128, 256, 512)
        elif self.depth in (50, 101, 152):
            self.dims = (256, 512, 1024, 2048)
        else:
            raise ValueError(f"unsupported ResNet depth: {self.depth}")
        self.attention1 = self.get_attention_module(self.dims[1])
        self.attention2 = self.get_attention_module(self.dims[2])
        self.attention3 = self.get_attention_module(self.dims[3])

    # Subclasses only need to implement this method.
    def get_attention_module(self, dim):
        raise NotImplementedError()

    def forward(self, x):
        outs = super().forward(x)
        outs = list(outs)
        outs[1] = self.attention1(outs[1])
        outs[2] = self.attention2(outs[2])
        outs[3] = self.attention3(outs[3])
        return tuple(outs)

@BACKBONES.register_module()
class ResNetWithCoordAttention(ResNetWithAttention):
    def __init__(self, **kwargs):
        super(ResNetWithCoordAttention, self).__init__(**kwargs)

    def get_attention_module(self, dim):
        return CoordAtt(inp=dim, oup=dim, reduction=32)

@BACKBONES.register_module()
class ResNetWithSEAttention(ResNetWithAttention):
    def __init__(self, **kwargs):
        super(ResNetWithSEAttention, self).__init__(**kwargs)

    def get_attention_module(self, dim):
        return SEAttention(channel=dim, reduction=16)

if __name__ == "__main__":
    # model = ResNet(depth=18)
    # model = ResNet(depth=34)
    # model = ResNet(depth=50)
    # model = ResNet(depth=101)
    # model = ResNet(depth=152)
    # model = ResNetWithCoordAttention(depth=18)
    model = ResNetWithSEAttention(depth=18)
    x = torch.rand(1, 3, 224, 224)
    outs = model(x)
    # outs is a tuple of four feature maps, so print each shape in turn.
    for i, out in enumerate(outs):
        print(i, out.shape)
Taking ResNet as an example, I add an attention mechanism to the feature-map outputs at multiple scales. The logic lives in a base class, so each subclass only needs to implement the attention module itself.
The attention implementations follow this open-source repository:
GitHub - xmu-xiaoma666/External-Attention-pytorch: 🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐
The package can also be installed directly with pip:
pip install fightingcv-attention
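
As a quick sanity check of the library API, an attention block from the package can be applied to a feature map on its own (a minimal sketch; the tensor shape here is arbitrary):

import torch
from fightingcv_attention.attention.SEAttention import SEAttention

# Apply an SE block to a dummy feature map of shape (batch, channels, H, W).
se = SEAttention(channel=512, reduction=16)
feat = torch.randn(1, 512, 7, 7)
out = se(feat)
print(out.shape)  # channel attention keeps the shape: torch.Size([1, 512, 7, 7])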
Once the model outputs have been tested, the backbone can be registered with MMDetection.
The simple way is to add the backbone registration decorator (as in the code above) and import this file in train.py and test.py.
Then, in the config, change the model's backbone type from ResNet to ResNetWithCoordAttention or ResNetWithSEAttention, as sketched below.
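
A minimal sketch of these two changes, assuming the code above lives in a file named resnet_with_attention.py (the file name and the remaining backbone options are assumptions modeled on a standard mmdet 2.x ResNet config):

# In train.py / test.py: import the file so the registration decorator runs.
import resnet_with_attention  # noqa: F401  (hypothetical module name)

# In the config: swap the backbone type; the other ResNet options stay as usual.
model = dict(
    backbone=dict(
        type='ResNetWithSEAttention',  # was 'ResNet'
        depth=18,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
    # ... neck, head, etc. unchanged
)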
Notes
(1) These changes target mmdet 2.x.
(2) Check the input/output shapes of the added modules, and also verify that the network outputs still match the design when the input width and height differ; see the sketch below.
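
For instance, a quick check with a non-square input (a minimal sketch; 480×640 is an arbitrary size) confirms the attention-wrapped backbone still produces the expected stride-4/8/16/32 feature maps:

import torch

# Feed a non-square input through the modified backbone and confirm each
# feature map keeps the expected downsampling ratio.
model = ResNetWithSEAttention(depth=18)
x = torch.rand(1, 3, 480, 640)
for i, out in enumerate(model(x)):
    print(i, out.shape)
# Expected: (1, 64, 120, 160), (1, 128, 60, 80), (1, 256, 30, 40), (1, 512, 15, 20)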
(3) The attention modules sit at the end of each scale so that as much of the pretrained weights as possible can be reused (maximizing pretrained-weight utilization). If training from scratch, I cannot say which placement works best.
(4) Adding an attention mechanism is not guaranteed to improve results; it is only worth an attempt. If the problem a particular attention mechanism solves matches your data, give it a try: for example, use spatial attention when you need stronger long-range spatial feature extraction, or simply use channel attention. If your dataset is small or the task is simple, avoid models with many parameters.