DETR3D代码阅读

这篇具有很好参考价值的文章主要介绍了DETR3D代码阅读。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

前言

本文主要是自己在阅读DETR3D的源码时的一个记录,如有错误或者问题,欢迎指正

提取feature map

在projects\mmdet3d_plugin\models\detectors\detr3d.py的forward_train()中,首先通过res50和FPN来进行图片特征的提取

        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        losses = dict()
        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                            gt_labels_3d, img_metas,
                                            gt_bboxes_ignore)
        losses.update(losses_pts)

提取到的img_feats为 [num_level,bs,6,c,h,w]
DETR3D代码阅读
然后调用self.forward_pts_train,进入到self.forward_pts_train中,首先调用self.pts_bbox_head来计算前向过程的输出,代码由此进入到detr3d_head中

detr3d_head

进入transformer

  hs, init_reference, inter_references = self.transformer(
            mlvl_feats,  #经过resnet和FPN提取到的多尺度特征
            query_embeds, #[900,512]  [num_query,embed_dims*2]
            reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501 reg_banches是回归分支
            img_metas=img_metas,
        )

DETR3D这里的transformer只有decoder,没有encoder,整个transformer的代码如下:

		assert query_embed is not None
        bs = mlvl_feats[0].size(0)
        # 首先将query_embed分为query 和 query_pos
        query_pos, query = torch.split(query_embed, self.embed_dims , dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)  # [1, 900, 256] 1是batch_size
        query = query.unsqueeze(0).expand(bs, -1, -1)   # [1, 900, 256] 1是batch_size
        # 通过linear和sigmoid从query_pos中获取到reference_points
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()   #[1, 900, 3] [bs,num_query,3]
        init_reference_out = reference_points

        # decoder
        query = query.permute(1, 0, 2)      #[1,900,256] --> [900, 1, 256] [num_query,bs,256]
        query_pos = query_pos.permute(1, 0, 2)  #[1,900,256] --> [900, 1, 256][num_query,bs,256]

      	# 进入到Detr3DTransformerDecoder
        inter_states, inter_references = self.decoder(
            query=query, # [num_query,bs,256]
            key=None,   
            value=mlvl_feats, #value就是提取出的图片特征 [num_level,bs,num_cam,c,h,w]
            query_pos=query_pos, # [num_query,bs,256]
            reference_points=reference_points, #[bs,num_query,3]
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        #inter_states是sample到的feature   inter_references_out是更新后的referencepoints
        return inter_states, init_reference_out, inter_references_out

decoder

decoder中先做self_attn,此时QKV都是query,shape为[900,bs,265],然后做cross_atten,在cross_attn中

        if key is None:
            key = query   # key就等于query
        if value is None:
            value = key

        if residual is None:
            inp_residual = query
        if query_pos is not None:
            query = query + query_pos

key和query是一样的,value是多尺度的feature map,虽然这里有了key,但是其实也没有用K乘Q去计算attention_weight,其attention_weight依然是通过query出的。

        query = query.permute(1, 0, 2)      # (1,900,256)

        bs, num_query, _ = query.size()      #bs=1, num_query=900

        # (1,1,900,12,1,4) num_cams=12 num_points=1 num_levels=4
        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)

        # 返回值reference_points_3d就是原来的3d坐标,output是sampled_feats
        # output=B, C, num_query, num_cam,  1, len(mlvl_feats)] reference_points_3d=[1,900,3]
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)

        attention_weights = attention_weights.sigmoid() * mask
        # 个人理解:这里的output就是attention中的value
        output = output * attention_weights
        # 连续三个sum(-1),将不同尺度和不同相机的feature求和,得到最终图像的特征
        output = output.sum(-1).sum(-1).sum(-1)
        output = output.permute(2, 0, 1)
        # 图像特征project到与query同维度 [bs,256,num_query,num_cam,1,num_level] --> [num_query, bs, 256]
        output = self.output_proj(output)     
        # (num_query, bs, embed_dims)
        pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute(1, 0, 2) #MLP
		
		# 输出 = sampled到的feature,原来的query,reference_points_3d的pos_feat
        return self.dropout(output) + inp_residual + pos_feat

最重要的部分就是feature_sampling这个函数

def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    lidar2img = []
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4) (1,6,4,4)
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()
    # 归一化坐标  pc_range =[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
    # 将坐标从0-1尺度转到lidar坐标系下
    reference_points[..., 0:1] = reference_points[..., 0:1]*(pc_range[3] - pc_range[0]) + pc_range[0]  #x轴
    reference_points[..., 1:2] = reference_points[..., 1:2]*(pc_range[4] - pc_range[1]) + pc_range[1]  #y轴
    reference_points[..., 2:3] = reference_points[..., 2:3]*(pc_range[5] - pc_range[2]) + pc_range[2]  #z轴

    # reference_points (B, num_queries, 4)   在最后一列全加上1 变成(1,900,4)
    reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
    ###############################################
    # 2.由lidar系转化为camera系 
    ###############################################
    B, num_query = reference_points.size()[:2]         #B=1,num_query=900
    num_cam = lidar2img.size(1)   # num_cam=6
    # reference_points[1,900,4] --> reference_points.view(B, 1, num_query, 4) [1,1,900,4] --> repeat [1,12,900,4] --> [1,12,900,4,1]
    reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1) #[1, 12, 900, 4, 4]
    # reference_points_cam.size() = [1,6,900,4]
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
    ###############################################
    # 3.由camera系转到图像系并归一化
    ###############################################
    eps = 1e-5
    # mask.size() = [1,6,900,1]
    mask = (reference_points_cam[..., 2:3] > eps)
    # 这一步是将坐标由camera系转到图像系 (x,y) = (xc,yc) / zc *f  这里的f是相机焦距,在前面lidar2img已经成过了,这里只用除以zc就行了
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3])*eps)  # 深度归一化 		(1, 6, 900, 2)
     # 在img平面上进行长宽归一化
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1] # 长宽归一化
    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
	#将坐标由[0,1] 转到[-1,1]之间
    reference_points_cam = (reference_points_cam - 0.5) * 2
    # 对所有不在grid内的点,也就是投影在某个cam之外的点进行mask
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0) 
                 & (reference_points_cam[..., 0:1] < 1.0) 
                 & (reference_points_cam[..., 1:2] > -1.0) 
                 & (reference_points_cam[..., 1:2] < 1.0))
    # mask.size()=[1,1,900,6,1,1]
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
    mask = torch.nan_to_num(mask)
    sampled_feats = []
    # 逐特征层sample feature
    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size()   #B=1,N=6 C=256 H=16 W=28
        feat = feat.view(B*N, C, H, W)
        reference_points_cam_lvl = reference_points_cam.view(B*N, num_query, 1, 2)
        sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        sampled_feats.append(sampled_feat)
    sampled_feats = torch.stack(sampled_feats, -1)
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam,  1, len(mlvl_feats))
    return reference_points_3d, sampled_feats, mask

在通过feature_sampling提取特征后,将得到的output首先和attention_weigth相乘,然后连续三个sum(-1),将不同相机,不同尺度的feature直接相加,将这些特征都融合再一起,再通过一个out_proj的线性层,将这些特征转换到与query相同的维度,最后的输出就是提取到的特征和原始的query以及reference_points3d的位置编码的和。

在经过每一层的decoder layer之后,会有一个回归分支来预测bbox,会根据预测出的bbox来更新reference point

整个Detr3DTransformerDecoder的代码如下:

  		output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            reference_points_input = reference_points
            # 由此进入DetrTransformerDecoderLayer
            # 返回的output为[900,1,256]
            output = layer(
                output,
                *args,
                reference_points=reference_points_input,
                **kwargs)
            output = output.permute(1, 0, 2)
			
            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                
                assert reference_points.shape[-1] == 3

                new_reference_points = torch.zeros_like(reference_points)
				# x y
                new_reference_points[..., :2] = tmp[
                    ..., :2] + inverse_sigmoid(reference_points[..., :2])
                 # z
                new_reference_points[..., 2:3] = tmp[
                    ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
                
                new_reference_points = new_reference_points.sigmoid()

                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points

整个decoder部分返回的是每一个decoder layer输出的query和reference point

走完整个transformer之后,后面就是通过输出的每一层的feature和referencepoint来进行预测。文章来源地址https://www.toymoban.com/news/detail-469338.html

到了这里,关于DETR3D代码阅读的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • DETR训练自己的数据集

    DETR是一个利用transformer实现端到端目标检测的模型。本文记录利用官方提供的代码来训练验证自己的数据集的过程以及一些注意事项。 此次项目用到的数据集为自己制造的自动驾驶领域的路况数据集,该数据集一共包含57个类别: names = [ \\\"i2\\\", \\\"i4\\\", \\\"i5\\\", \\\"il100\\\", \\\"il60\\\", \\\"il80\\\", \\\"io

    2023年04月17日
    浏览(42)
  • 【DETR】训练自己的数据集-实践笔记

    DETR(Detection with TRansformers)是基于transformer的端对端目标检测,无NMS后处理步骤,无anchor。 实现使用NWPUVHR10数据集训练DETR. NWPU数据集总共包含十种类别目标,包含650个正样本,150个负样本(没有用到)。 代码:https://github.com/facebookresearch/detr 1.数据集准备 DETR需要的数据集格式

    2024年02月05日
    浏览(45)
  • 3d激光slam建图与定位(2)_aloam代码阅读

    1.常用的几种loam算法 aloam 纯激光 lego_loam 纯激光 去除了地面 lio_sam imu+激光紧耦合 lvi_sam 激光+视觉 2.代码思路 2.1.特征点提取scanRegistration.cpp,这个文件的目的是为了根据曲率提取4种特征点和对当前点云进行预处理 输入是雷达点云话题 输出是 4种特征点点云和预处理后的当前

    2024年02月11日
    浏览(39)
  • 计算机视觉算法——基于Transformer的目标检测(DETR / Deformable DETR / Dynamic DETR / DETR 3D)

    DETR是DEtection TRansformer的缩写,该方法发表于2020年ECCV,原论文名为《End-to-End Object Detection with Transformers》。 传统的目标检测是基于Proposal、Anchor或者None Anchor的方法,并且至少需要非极大值抑制来对网络输出的结果进行后处理,涉及到复杂的调参过程。而DETR使用了Transformer

    2024年02月09日
    浏览(55)
  • RT-DETR论文阅读笔记(包括YOLO版本训练和官方版本训练)

    论文地址: RT-DETR论文地址 代码地址: RT-DETR官方下载地址 大家如果想看更详细训练、推理、部署、验证等教程可以看我的另一篇博客里面有更详细的介绍 内容回顾: 详解RT-DETR网络结构/数据集获取/环境搭建/训练/推理/验证/导出/部署  目录 一、介绍  二、相关工作 2.1、实

    2024年02月03日
    浏览(42)
  • DETR 系列有了新发现?DETRs with Hybrid Matching 论文阅读笔记

    写在前面   有个城市之星的活动,加紧赶一篇博文出来吧。这是 VALSE 2023 大会(VALSE 2023 无锡线下参会个人总结 6月11日-2)上的一篇 Poster 论文,遂找来读读。 论文地址:DETRs with Hybrid Matching 代码地址:https://github.com/HDETR 收录于:CVPR 2023 PS:2023 每周一篇博文,主页 更多干

    2024年02月07日
    浏览(47)
  • 【colab】谷歌colab免费服务器训练自己的模型,本文以yolov5为例介绍流程

    目录 一.前言 二.准备工作 1.注册Google drive(谷歌云盘) Google Driver官网:https://drive.google.com/drive/ Colab官网:https://colab.research.google.com/ 2.上传项目文件 3.安装Colaboratory 4.colab相关操作和命令 5.项目相关操作  三.异常处理         本文介绍了在谷歌开放平台Google colab上租用免

    2023年04月08日
    浏览(53)
  • 【微信小程序】如何获得自己当前的定位呢?本文利用逆地址解析、uni-app带你实现

    目录 前言 效果展示 一、在腾讯定位服务配置微信小程序JavaScript SDK 二、使用uni-app获取定位的经纬度 三、 逆地址解析,获取精确定位 四、小提示 在浏览器搜索腾讯定位服务,找到官方网站,利用微信或者其他账号注册登录,登录后如下图操作 点进去之后,可以看到如下图

    2024年01月19日
    浏览(87)
  • 人工智能学习07--pytorch23--目标检测:Deformable-DETR训练自己的数据集

    1、pytorch conda create -n deformable_detr python=3.9 pip 2、激活环境 conda activate deformable_detr 3、torch 4、其他的库 pip install -r requirements.txt 5、编译CUDA cd ./models/ops sh ./make.sh #unit test (should see all checking is True) python test.py (我没运行这一步) 主要是MultiScaleDeformableAttention包,如果中途换了

    2024年02月14日
    浏览(159)
  • Linux系列文章 —— vim的基本操作(误入vim退出请先按「ESC」再按:q不保存退出,相关操作请阅读本文)

    vim-操作篇 进程概念篇 进程地址空间篇 Linux,是一种免费使用和自由传播的类UNIX操作系统,是一个基于POSIX的多用户、多任务、支持多线程和多CPU的操作系统。它能运行主要的Unix工具软件、应用程序和网络协议。Linux继承了Unix以网络为核心的设计思想,是一个性能稳定的多用

    2024年02月03日
    浏览(46)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包