Deformable DETR源码解读

这篇具有很好参考价值的文章主要介绍了Deformable DETR源码解读。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一:网络创新点

传统DETR提出的encoder-decoder结构,将transformer运用到了目标检测领域,在我看来属于Resnet相对于Alexnet的里程碑级别思路很开辟但是细节还欠打磨,我分析一下DETR中的缺点:

  • 收敛速度慢。因为keys的选取自整个特征图上的每个像素点,复杂度是指数级别的暴增。注意力初始分布十分平均,Dense-to-Sparse的效果不好。
  • 精度不高,特别是对于小目标检测效果更差。原因用论文中的话说,the deficits of Transformer attention in handling image feature maps as key elements,Modern object detectors use high-resolution feature maps to better detect small objects. However, high-resolution feature maps would lead to an unacceptable complexity for the self-attention module in the Transformer encoder of DETR, which has a quadratic complexity with the spatial size of input feature maps。究其原因是特征图处理模块少,也没有什么类似FPN这种低维和高维特征融合的手段。

针对以上的几个问题,Deformable DETR依次提出如下思路:

  • key的选取不再是全图所有的像素点,而是每一个query在特征图上对应一个reference_point,基于每个reference_point再选取n = 4(源码中设置)个keys,根据Linear生成的attention_weights进行特征融合(注意注意力权重不是Q*k算来的,而是对query直接Linear得到的)。这样大大提高了收敛速度,而是有选择性的注意Sparse区域来训练attention
  • 为了提高小目标检测效果,没有使用FPN,而是提取了backbone中C3~C5和用3✖3 kernel_size、(2, 2)stride得到的C6这四个特征图,每个query的head在这四个各取4个key,然后融合更新
  • 后期作者还增加了Iterative Bounding Box Refinement,根据decoder上一层Layer输出结果,迭代更新bounding box,大大提高了预测准确率。
  • 作者还增加了two-stage升级版结构,回到了检测的经典思想中,性能参数都有一定提高。由于较复杂,这里暂不讲。

二:流程详解

【part 1】deformable_detr模块

  • 首先分析deformable_detr模块,从backbone的C3~C5提取出3个srcs和pos_embeds,将C5进行stride=2的下采样操作,得到第四个src和pos_embed。然后对四个srcs进行Linear,把channels变为hidden_dim,得到下图结果,pos_embeds的shape和变换通道后的srcs的shape相同:
    Deformable DETR源码解读
  • deformable_detr模块还初始化了query_embeds,self.query_embed = nn.Embedding(num_queries, hidden_dim*2),即(10, 128),10是代码中设置的query_num。值得注意的是128,因为这里的self.query_embed一半是tgt,一半是pos_embeds。
  • 将它们传给deformable_transformer模块中,self.transformer(srcs, masks, pos, query_embeds)

【part 2】deformable_transformer模块

  • 首先对传入的数据做flatten()处理,打印如下:
    Deformable DETR源码解读
  • 接着将处理后的数据传入encoder模块中, memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten),让我们一起进入encoder模块看一看
【part3】Encoder模块

Deformable DETR源码解读

  • 首先通过self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)得到reference_points,shape为 [2, 15060, 4 , 2],得到的是在每一层特征图中的相对位置(0 ~ 1)。处理方法如下:
    Deformable DETR源码解读
  • 接下来进入EncoderLayer层中,传入数据的shape可见图,EncoderLayer的forward结构如下:
    Deformable DETR源码解读下面让我们重点看一下网络核心模块MSDeformAttn,对应着self.self_attn()
【part 4】MSDeformAttn

Deformable DETR源码解读

  • 就是将加了pos_embeds的srcs作为query传入,通过Linear生成sampling_offsets和attention_weights,分别对应着每个query的每个head在每个特征层选取的4个keys和权重,可见这里的weight不是QK后生成的,而是直接Linear得到的。

  • 最后传入MSDeformAttnFunction功能模块进行特征融合,实现细节略,输出memory。

  • 结束了encoder模块,输出了memory。退回到deformable_transformer模块:
    Deformable DETR源码解读

  • 可见,就是将10个query_embed做了一下复制、拆分,得到真正的query_embed(decoder中也作为query_pos)和tgt,接着将query_embed传入Linear中得到reference_points,最后都传入Decoder中

【part5】Decoder模块

Deformable DETR源码解读Deformable DETR源码解读

  • 简单处理一下reference_points后,循环进入DecoderLayer中,可以对中间output和reference_points储存,如果加了bbox refinement那么reference_points会一次次改变。Layer结构如下:

Deformable DETR源码解读

  • 先是自注意力,注意这里没有使用MSDeformAttn,而是正常的MutiheadAttention。然后交叉注意力,得到最终结果。

最后,让我们回到Deformable_Detr模块,从self.transformer中输出结果如下:

Deformable DETR源码解读后面根据任务转换输出结果的channels,之后就是基本的匈牙利匹配➕损失计算了,和Detr差不多。有一点值得注意,bbox的pred结果是reference_point + self.bbox_embed(hs[i])[…,:2]。相当于网络输出预测是长、宽和基于reference_point的偏移量!!!


  至此我对Deformable DETR源码中全部的流程与细节,进行了深度讲解,希望对大家有所帮助,有不懂的地方或者建议,欢迎大家在下方留言评论。

我是努力在CV泥潭中摸爬滚打的江南咸鱼,我们一起努力,不留遗憾!文章来源地址https://www.toymoban.com/news/detail-439315.html

到了这里,关于Deformable DETR源码解读的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 计算机视觉算法——基于Transformer的目标检测(DETR / Deformable DETR / Dynamic DETR / DETR 3D)

    DETR是DEtection TRansformer的缩写,该方法发表于2020年ECCV,原论文名为《End-to-End Object Detection with Transformers》。 传统的目标检测是基于Proposal、Anchor或者None Anchor的方法,并且至少需要非极大值抑制来对网络输出的结果进行后处理,涉及到复杂的调参过程。而DETR使用了Transformer

    2024年02月09日
    浏览(54)
  • DPText-DETR原理及源码解读(二)

    理解中。。。 接下来深入最难的DeformableTransformer_Det,这个py文件包含了多个class DeformableTransformer_Det DeformableTransformerEncoderLayer DeformableTransformerEncoder CirConv 环形卷积 DeformableTransformerDecoderLayer_Det DeformableTransformerDecoder_Det 多尺度可变形注意力 先说结论:随batch增大,显存占用变

    2024年02月05日
    浏览(44)
  • DEFORMABLE DETR: DEFORMABLE TRANSFORMERS FOR END-TO-END OBJECT DETECTION 论文精读笔记

    DEFORMABLE DETR: DEFORMABLE TRANSFORMERS FOR END-TO-END OBJECT DETECTION 参考:AI-杂货铺-Transformer跨界CV又一佳作!Deformable DETR:超强的小目标检测算法! 摘要 摘要部分,作者主要说明了如下几点: 为了解决DETR中使用Transformer架构在处理图像特征图时的局限性而导致的收敛速度慢,特征空间

    2024年02月10日
    浏览(39)
  • 【RT-DETR有效改进】结合SOTA思想利用双主干网络改进RT-DETR(全网独家创新,重磅更新)

    本文给大家带来的改进机制是结合目前 SOTAYOLOv9的思想 利用双主干网络来改进RT-DETR(本专栏目前发布以来改进最大的内容,同时本文内容为我个人一手整理全网独家首发 | 就连V9官方不支持的模型宽度和深度修改我都均已提供, 本文内容支持RT-DETR全系列模型均可使用 ) ,本

    2024年03月16日
    浏览(59)
  • Transformer实战-系列教程21:DETR 源码解读8 损失计算:(SetCriterion类)

    有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Pycharm中进行 本篇文章配套的代码资源已经上传 点我下载源码 DETR 算法解读 DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类) DETR 源码解读2(DETR类) DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)

    2024年02月19日
    浏览(43)
  • 人工智能学习07--pytorch23--目标检测:Deformable-DETR训练自己的数据集

    1、pytorch conda create -n deformable_detr python=3.9 pip 2、激活环境 conda activate deformable_detr 3、torch 4、其他的库 pip install -r requirements.txt 5、编译CUDA cd ./models/ops sh ./make.sh #unit test (should see all checking is True) python test.py (我没运行这一步) 主要是MultiScaleDeformableAttention包,如果中途换了

    2024年02月14日
    浏览(158)
  • 详细理解(学习笔记) | DETR(整合了Transformer的目标检测框架) DETR入门解读以及Transformer的实操实现

    DETR ,全称 DEtection TRansformer,是Facebook提出的基于Transformer的端到端目标检测网络,发表于ECCV2020。 原文: 链接 源码: 链接 DETR 端到端目标检测网络模型,是第一个将 Transformer 成功整合为检测pipline中心构建块的目标检测框架模型。基于Transformers的端到端目标检测,没有NMS后

    2024年02月04日
    浏览(56)
  • 【RT-DETR有效改进】ShapeIoU、InnerShapeIoU关注边界框本身的IoU(包含二次创新)

    👑欢迎大家订阅本专栏,一起学习RT-DETR👑  本文给大家带来的改进机制是ShapeIoU其是一种关注边界框本身形状和尺度的边界框回归方法(IoU),同时本文的内容包括过去到现在的百分之九十以上的损失函数的实现,使用方法非常简单,在本文的末尾还会教大家在改进模型时

    2024年01月16日
    浏览(95)
  • 【西安交通大学】:融合传统与创新的学府之旅

    🎉博客主页:小智_x0___0x_ 🎉欢迎关注:👍点赞🙌收藏✍️留言 🎉系列专栏:小智带你闲聊 🎉代码仓库:小智的代码仓库 西安交通大学是国家教育部直属重点大学,为我国最早兴办的高等学府之一。其前身是1896年创建于上海的南洋公学,1921年改称交通大学,1956年国务院

    2024年02月15日
    浏览(43)
  • 【RT-DETR改进涨点】MPDIoU、InnerMPDIoU损失函数中的No.1(包含二次创新)

    👑欢迎大家订阅本专栏,一起学习RT-DETR👑  本文给大家带来的改进机制是 最新的 损失函数 MPDIoU (Minimum Point Distance Intersection over Union) 其是 一种新的边界框相似度度量方法 。MPDIoU是基于水平矩形的最小点距离来计算的,能够综合考虑重叠区域、中心点距离以及宽度和高

    2024年01月16日
    浏览(49)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包