CenterPoint: Paper and Code Walkthrough


Contents

1. Introduction

2. Paper Structure

3. Code


Paper: https://arxiv.org/pdf/2006.11275.pdf

Code: tianweiy/CenterPoint (github.com)

1. Introduction

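CenterPoint is an anchor-free method: it directly predicts object centers and then regresses each object's size (w, h, l). This removes the anchor-to-GT matching step (traditional anchor-based methods must compute the IoU between GT boxes and anchors to assign targets), and point-based predictions also make downstream tasks such as tracking easier. The experiments at the end of the paper indicate that the method learns object orientation better: because each object starts out as a single point with no orientation, the model is forced to learn more about the rotation angle. Anchor-based methods, in contrast, start from the anchor prior, which makes the model converge more easily.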
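To contrast the two assignment styles, the sketch below puts a simplified anchor matcher next to center-based assignment. It is illustrative only: the axis-aligned BEV IoU and the single positive threshold pos_thr are assumptions, not the PointPillars assigner, and all function names are hypothetical. The point is that anchor matching needs an IoU between every anchor and every GT, while center-based assignment only needs the cell that contains each center.

```python
import numpy as np


def iou_bev_axis_aligned(boxes_a, boxes_b):
    """Axis-aligned BEV IoU; boxes are (cx, cy, l, w). Returns an (N, M) matrix."""
    a_min = boxes_a[:, None, :2] - boxes_a[:, None, 2:] / 2   # (N, 1, 2)
    a_max = boxes_a[:, None, :2] + boxes_a[:, None, 2:] / 2
    b_min = boxes_b[None, :, :2] - boxes_b[None, :, 2:] / 2   # (1, M, 2)
    b_max = boxes_b[None, :, :2] + boxes_b[None, :, 2:] / 2
    overlap = np.clip(np.minimum(a_max, b_max) - np.maximum(a_min, b_min), 0, None)
    inter = overlap[..., 0] * overlap[..., 1]                  # (N, M)
    area_a = boxes_a[:, 2] * boxes_a[:, 3]
    area_b = boxes_b[:, 2] * boxes_b[:, 3]
    return inter / (area_a[:, None] + area_b[None, :] - inter)


def anchor_assign(anchors, gts, pos_thr=0.6):
    """Anchor-based: an anchor takes its best-overlapping GT if IoU > pos_thr, else -1.

    Assumes at least one GT box.
    """
    iou = iou_bev_axis_aligned(anchors, gts)                   # (num_anchors, num_gt)
    best_gt = iou.argmax(axis=1)
    return np.where(iou.max(axis=1) > pos_thr, best_gt, -1)


def center_assign(gts, x_min, y_min, cell_size, width):
    """Center-based: each GT owns the single cell containing its center (no IoU needed)."""
    ix = np.floor((gts[:, 0] - x_min) / cell_size).astype(np.int64)
    iy = np.floor((gts[:, 1] - y_min) / cell_size).astype(np.int64)
    return iy * width + ix                                     # flattened cell index
```

For example, a 4 x 2 GT centered at (1, 1) on a grid with 0.5 m cells and 10 columns lands in flattened cell 2 * 10 + 2 = 22, without touching any anchor.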

2. Paper Structure


The overall network architecture is very similar to PointPillars; the main change is that the head is anchor-free. The analysis below therefore focuses on the head.

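In the front part, the point cloud is processed by the VFE, scattered onto the BEV plane, and passed through an FPN neck to obtain a [B,C,H,W] feature map. A shared conv then adjusts the number of channels, and the result goes through five heads (each is just a stack of convs plus a final conv that reduces the channels to the required dimension), producing reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W], and hm [B,8,H,W]. The predicted reg is the offset within a single feature-map cell; it mainly compensates for the quantization error introduced when the continuous center is truncated to an integer cell index.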
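To make the head layout concrete, here is a minimal numpy sketch of the five heads applied to a [B, C, H, W] feature map. Assumptions: each head is one hidden conv + ReLU followed by an output conv, 1x1 convolutions written with einsum stand in for the 3x3 convs with BatchNorm used in the actual code, and the random weights are placeholders; HEADS, conv1x1, and run_heads are hypothetical names. The heatmap bias of -2.19 is roughly -log((1 - 0.1) / 0.1), so the initial sigmoid output is about 0.1.

```python
import numpy as np

# Output channels per head, as listed above (8 heatmap classes)
HEADS = {"reg": 2, "height": 1, "dim": 3, "rot": 2, "hm": 8}


def conv1x1(x, weight, bias):
    """1x1 convolution on [B, C_in, H, W]; weight is [C_out, C_in], bias is [C_out]."""
    return np.einsum("oc,bchw->bohw", weight, x) + bias[None, :, None, None]


def run_heads(feat, head_conv=64, hm_bias=-2.19, seed=0):
    """Apply one hidden conv + ReLU and one output conv per head (random placeholder weights)."""
    rng = np.random.default_rng(seed)
    c_in = feat.shape[1]
    outputs = {}
    for name, out_ch in HEADS.items():
        w1 = rng.standard_normal((head_conv, c_in)) * 0.01
        hidden = np.maximum(conv1x1(feat, w1, np.zeros(head_conv)), 0.0)  # conv + ReLU
        w2 = rng.standard_normal((out_ch, head_conv)) * 0.01
        b2 = np.full(out_ch, hm_bias if name == "hm" else 0.0)          # heatmap bias prior
        outputs[name] = conv1x1(hidden, w2, b2)                          # [B, out_ch, H, W]
    return outputs
```

With an all-zero input every heatmap logit equals hm_bias, i.e. sigmoid(-2.19) is about 0.10, while the spatial size H x W is preserved by every head.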

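Inference: take the exponential of dim, recover the angle from the sine and cosine in rot, and add reg to the grid coordinates generated by meshgrid to obtain absolute coordinates on the feature map, which are then converted back to lidar coordinates using out_size_factor, voxel_size, and pc_range. These are concatenated into boxes of shape [B,H*W,7]. In parallel, a sigmoid is applied to hm, and both are passed to post-processing: first the heatmap is maxed over the channel dimension to get each cell's score and label; then a mask is built from the per-class score thresholds (combined with a range check on the box centers) to decide which cells to keep; finally NMS removes redundant boxes. Only the one-stage version is described here; the paper uses a two-stage design with an additional box-refinement stage. Note: CenterPoint does use NMS (rotated BEV NMS in the code below).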
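The decoding steps above can be sketched for a single sample in numpy as follows. This is an illustrative sketch rather than the predict method itself: decode_single is a hypothetical name, it works on [C, H, W] arrays instead of the permuted [H, W, C] tensors, and the range filter and NMS are left out.

```python
import numpy as np


def decode_single(hm_logits, reg, height, dim_log, rot,
                  out_size_factor, voxel_size, pc_range, score_thr):
    """Decode one sample's head outputs ([C, H, W] arrays) into boxes.

    rot holds (sin, cos); dim_log holds log(l, w, h).
    Returns boxes [N, 7] as (x, y, z, l, w, h, yaw), scores [N], labels [N].
    """
    num_cls, H, W = hm_logits.shape
    hm = 1.0 / (1.0 + np.exp(-hm_logits))                  # sigmoid
    ys, xs = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
    # cell index + predicted offset, then feature-map cells -> lidar coordinates
    x = (xs + reg[0]) * out_size_factor * voxel_size[0] + pc_range[0]
    y = (ys + reg[1]) * out_size_factor * voxel_size[1] + pc_range[1]
    yaw = np.arctan2(rot[0], rot[1])                       # atan2(sin, cos)
    size = np.exp(dim_log)                                 # undo the log encoding
    boxes = np.stack([x, y, height[0], size[0], size[1], size[2], yaw], axis=-1)
    boxes = boxes.reshape(H * W, 7)                        # row-major: index = y * W + x
    flat_hm = hm.reshape(num_cls, H * W)
    scores, labels = flat_hm.max(axis=0), flat_hm.argmax(axis=0)
    keep = scores > np.asarray(score_thr)[labels]          # per-class threshold
    return boxes[keep], scores[keep], labels[keep]
```

For example, a single strong class-1 logit at row 1, column 2 with zero offsets, out_size_factor = 2, and 0.5 m voxels decodes to a box centered at x = 2.0, y = 1.0.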

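Training: the GT heatmap and boxes have to be built first, so hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500], and cat [B,500] are zero-initialized. Because samples contain different numbers of GT objects, every sample is padded to a fixed maximum of 500 objects, and mask marks which slots hold a real GT. For each GT, a Gaussian is drawn on the heatmap channel of its class. The Gaussian radius is computed from the box's w and l and a minimum IoU overlap (gaussian_overlap); for details see the Zhihu article "说点Cornernet/Centernet代码里面GT heatmap里面如何应用高斯散射核" (on how CornerNet/CenterNet code applies a Gaussian kernel to the GT heatmap), which covers three cases: the shifted box lies inside the GT, encloses it, or partially overlaps it. The author also clamps the radius to a minimum value (min_radius). Next, find the feature-map cell that contains the center; subtracting the integer cell coordinates from the continuous center gives the offset. l, w, and h are log-encoded and the yaw angle is encoded as sin and cos; together these form anno_box. ind is the index of the center in the flattened H*W map, and cat is the object's class. This yields the example dict. The Gaussian itself weights each cell by exp(-d² / (2σ²)), where d is the distance to the center in cells and σ = (2r + 1) / 6, so values approach 1 near the center.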
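The Gaussian drawing described above can be sketched as a compact numpy version of draw_umich_gaussian that evaluates the kernel only over the window clipped to the heatmap border. It assumes an integer center and an already-computed radius (for example max(min_radius, int(gaussian_radius(...)))); draw_gaussian_peak is a hypothetical name, and unlike gaussian2D it does not zero out values below machine epsilon.

```python
import numpy as np


def draw_gaussian_peak(heatmap, cx, cy, radius):
    """Write exp(-d^2 / (2 sigma^2)) around integer (cx, cy), keeping the element-wise max.

    sigma = (2 * radius + 1) / 6, as in draw_umich_gaussian.
    """
    sigma = (2 * radius + 1) / 6.0
    h, w = heatmap.shape
    # clip the (2r+1) x (2r+1) window to the heatmap border
    y0, y1 = max(0, cy - radius), min(h, cy + radius + 1)
    x0, x1 = max(0, cx - radius), min(w, cx + radius + 1)
    yy, xx = np.mgrid[y0:y1, x0:x1]
    g = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma ** 2))
    heatmap[y0:y1, x0:x1] = np.maximum(heatmap[y0:y1, x0:x1], g)
    return heatmap
```

With radius 2, sigma is 5/6, so a cell one step away from the center gets exp(-0.72), about 0.49; taking the maximum means two nearby objects of the same class keep both peaks at 1.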

At this point we have the GT tensors hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500], and cat [B,500].
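For a single sample, the anno_box, ind, mask, and cat tensors can be produced roughly as follows. This is a minimal numpy sketch, not the AssignLabel code itself: encode_box and encode_sample are hypothetical helpers, boxes are assumed to be in OpenPCDet's (x, y, z, dx, dy, dz, yaw) order with 0-based labels, the heatmap is left out, and np.floor is used instead of truncation so that negative coordinates fall outside the map.

```python
import numpy as np


def encode_box(box, pc_range, voxel_size, out_size_factor, feat_w, feat_h):
    """Encode one box (x, y, z, l, w, h, yaw) into an 8-dim target and a flat index.

    Returns None if the center falls outside the feature map.
    """
    x, y, z, l, w, h, yaw = box
    # continuous center on the downsampled feature map
    cx = (x - pc_range[0]) / voxel_size[0] / out_size_factor
    cy = (y - pc_range[1]) / voxel_size[1] / out_size_factor
    ix, iy = int(np.floor(cx)), int(np.floor(cy))          # cell containing the center
    if not (0 <= ix < feat_w and 0 <= iy < feat_h):
        return None
    target = np.array([cx - ix, cy - iy, z,                # sub-cell offset and raw z
                       np.log(l), np.log(w), np.log(h),    # log-encoded size
                       np.sin(yaw), np.cos(yaw)],          # yaw as sin / cos
                      dtype=np.float32)
    return target, iy * feat_w + ix                        # row-major index into H*W


def encode_sample(boxes, labels, max_objs, **kw):
    """Pad one sample's targets to max_objs slots; mask marks the valid ones."""
    anno_box = np.zeros((max_objs, 8), dtype=np.float32)
    ind = np.zeros(max_objs, dtype=np.int64)
    mask = np.zeros(max_objs, dtype=np.uint8)
    cat = np.zeros(max_objs, dtype=np.int64)
    for k in range(min(len(boxes), max_objs)):
        encoded = encode_box(boxes[k], **kw)
        if encoded is None:
            continue
        anno_box[k], ind[k] = encoded
        mask[k], cat[k] = 1, labels[k]
    return anno_box, ind, mask, cat
```

The padding to a fixed max_objs is what lets samples with different object counts be stacked into one [B, max_objs, 8] tensor.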

The model predicts reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W], and hm [B,8,H,W].

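A sigmoid is applied to the predicted hm, and reg, height, dim, and rot are concatenated into pred_box [B,8,H,W]. pred_box is flattened to [B,H*W,8] and gathered with ind into [B,500,8], and a masked L1 loss against anno_box is computed, weighted per dimension by code_weights. The heatmap loss is computed directly on the full map with a CornerNet-style focal loss (FocalLossCenterNet in the code below; the FastFocalLoss variant is commented out).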
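The gather-then-L1 step and the heatmap focal loss can be sketched in numpy as below. This is an illustrative re-implementation under the shapes given above, not the code in the next section: gather_pred, masked_l1, and focal_loss are hypothetical names, the focal loss uses exponents 2 and 4 as in neg_loss_cornernet, and the optional spatial mask is omitted.

```python
import numpy as np


def gather_pred(pred, ind):
    """pred [B, C, H, W], ind [B, M] flattened H*W indices -> [B, M, C]."""
    B, C, H, W = pred.shape
    flat = pred.reshape(B, C, H * W).transpose(0, 2, 1)             # [B, H*W, C]
    idx = np.broadcast_to(ind[:, :, None], (B, ind.shape[1], C))    # like expand() in torch
    return np.take_along_axis(flat, idx, axis=1)                    # [B, M, C]


def masked_l1(pred, mask, ind, target):
    """Per-channel L1 over valid objects, normalized by the object count (as in RegLoss)."""
    p = gather_pred(pred, ind)
    m = mask.astype(np.float32)[:, :, None]                         # [B, M, 1]
    loss = np.abs(p * m - target * m) / (m.sum() + 1e-4)
    return loss.sum(axis=(0, 1))                                    # [C]


def focal_loss(prob, gt, alpha=2, beta=4):
    """CornerNet-style penalty-reduced focal loss on a sigmoid heatmap."""
    prob = np.clip(prob, 1e-4, 1 - 1e-4)
    pos = gt == 1
    pos_loss = (np.log(prob) * (1 - prob) ** alpha)[pos].sum()
    neg_loss = (np.log(1 - prob) * prob ** alpha * (1 - gt) ** beta)[~pos].sum()
    num_pos = pos.sum()
    return -neg_loss if num_pos == 0 else -(pos_loss + neg_loss) / num_pos
```

The (1 - gt)^4 factor is what makes cells just next to a center, where the Gaussian target is close to 1, contribute almost nothing as negatives.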

3. Code

import copy
import logging
import sys
from collections import OrderedDict, defaultdict

from torch import nn


import torch
import numpy as np
import torch.nn.functional as F

from ...ops.iou3d_nms import iou3d_nms_cuda
from ..model_utils import model_nms_utils


class Sequential(torch.nn.Module):
    r"""A sequential container.
    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.

    To make it easier to understand, given is a small example::

        # Example of using Sequential
        model = Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using Sequential with OrderedDict
        model = Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))

        # Example of using Sequential with kwargs(python 3.6+)
        model = Sequential(
                  conv1=nn.Conv2d(1,20,5),
                  relu1=nn.ReLU(),
                  conv2=nn.Conv2d(20,64,5),
                  relu2=nn.ReLU()
                )
    """

    def __init__(self, *args, **kwargs):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        for name, module in kwargs.items():
            if sys.version_info < (3, 6):
                raise ValueError("kwargs only supported in py36+")
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += len(self)
        it = iter(self._modules.values())
        for i in range(idx):
            next(it)
        return next(it)

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        if name is None:
            name = str(len(self._modules))
            if name in self._modules:
                raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        # i = 0
        for module in self._modules.values():
            # print(i)
            input = module(input)
            # i += 1
        return input




def rotate_nms_pcdet(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
    """
    :param boxes: (N, 7) [x, y, z, l, w, h, theta]
    :param scores: (N)
    :param thresh:
    :return:
    """
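    # transform back to pcdet's coordinate:
    # swap l and w and convert the yaw angle to OpenPCDet's convention.
    # This conversion comes from det3d (CenterPoint) and assumes the input boxes use
    # det3d's convention; boxes already in OpenPCDet format may not need it.
    boxes = boxes[:, [0, 1, 2, 4, 3, 5, -1]]
    boxes[:, -1] = -boxes[:, -1] - np.pi /2

    order = scores.sort(0, descending=True)[1]  # sort the N boxes by score, descending
    if pre_maxsize is not None:  # keep only the top pre_maxsize boxes before NMS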
        order = order[:pre_maxsize]

    boxes = boxes[order].contiguous()

    keep = torch.LongTensor(boxes.size(0))

    if len(boxes) == 0:
        num_out =0
    else:
        num_out = iou3d_nms_cuda.nms_gpu(boxes, keep, thresh)

    selected = order[keep[:num_out].cuda()].contiguous()

    if post_max_size is not None:
        selected = selected[:post_max_size]

    return selected 


def kaiming_init(
    module, a=0, mode="fan_out", nonlinearity="relu", bias=0, distribution="normal"
):
    assert distribution in ["uniform", "normal"]
    if distribution == "uniform":
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def gaussian_radius(det_size, min_overlap=0.5):
    """
    compute gaussian radius by min_overlap, you can get principle in <<CenterNet :Objects as Points>> paper
    """
    height, width = det_size  # box height and width on the feature map

    a1  = 1
    b1  = (height + width)
    c1  = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1  = (b1 + sq1) / 2

    a2  = 4
    b2  = 2 * (height + width)
    c2  = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2  = (b2 + sq2) / 2

    a3  = 4 * min_overlap
    b3  = -2 * min_overlap * (height + width)
    c3  = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3  = (b3 + sq3) / 2
    return min(r1, r2, r3)

def gaussian2D(shape, sigma=1):
    """
    compute gaussian
    """
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m+1,-n:n+1]  # e.g. y: [7, 1], x: [1, 7] for a 7x7 kernel

    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))  # e.g. [7, 7]; larger closer to the center
    h[h < np.finfo(h.dtype).eps * h.max()] = 0  # zero out negligible values (eps is the machine epsilon)
    return h


def draw_umich_gaussian(heatmap, center, radius, k=1):
    """
    draw gaussian in heatmap
    """
    diameter = 2 * radius + 1
    # compute gaussian value
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)  # diameter x diameter, e.g. 7x7 for radius 3

    x, y = int(center[0]), int(center[1])  # integer center coordinates

    height, width = heatmap.shape[0:2]

    # get gaussian map pos
    # clip the window when the center is closer than radius to the heatmap border
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)

    # get masked heatmap pos
    masked_heatmap  = heatmap[y - top:y + bottom, x - left:x + right]  # heatmap region to update (a view)
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]  # matching part of the kernel

    # guard against empty slices; added for debugging and normally always true
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)  # keep the element-wise max, in place
    return heatmap

def _gather_feat(feat, ind, mask=None):
    dim  = feat.size(2)  # 8
    # ind: [B, 500] -> [B, 500, 1] -> [B, 500, 8]; flattened feature-map index of each object
    ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)  # pick entries along dim 1 (H*W) -> [B, 500, 8]
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

def _transpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()  # [B,200,380,8]
    feat = feat.view(feat.size(0), -1, feat.size(3)) # [B,H*W,8]
    feat = _gather_feat(feat, ind)
    return feat

def _circle_nms(boxes, min_radius, post_max_size=83):
    """
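    NMS according to center distance, no use now.
    Note: circle_nms is neither defined nor imported in this file, so calling this helper raises a NameError.
    """
    keep = np.array(circle_nms(boxes.cpu().numpy(), thresh=min_radius))[:post_max_size]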

    keep = torch.from_numpy(keep).long().to(boxes.device)

    return keep 


class RegLoss(nn.Module):
  '''Regression loss for an output tensor
    Arguments:
      output (batch x dim x h x w)
      mask (batch x max_objects)
      ind (batch x max_objects)
      target (batch x max_objects x dim)
  '''
  def __init__(self):
    super(RegLoss, self).__init__()
  
  def forward(self, output, mask, ind, target):
    # output[B,8,200,380]  pred[B,500,8]
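    # gather the predictions at the GT indices; the mask marks which of the 500 slots hold a real box
    pred = _transpose_and_gather_feat(output, ind)
    mask = mask.float().unsqueeze(2)

    # L1 loss: both tensors are [B, 500, 8] and are multiplied by the mask; the loss is normalized
    # by the number of objects, then summed over B and the 500 slots, giving an 8-dim loss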
    loss = F.l1_loss(pred*mask, target*mask, reduction='none')
    loss = loss / (mask.sum() + 1e-4)
    loss = loss.transpose(2 ,0).sum(dim=2).sum(dim=1)
    return loss

class FastFocalLoss(nn.Module):
  '''
  Reimplemented focal loss, exactly the same as the CornerNet version.
  Faster and costs much less memory.
  '''
  def __init__(self):
    super(FastFocalLoss, self).__init__()

  def forward(self, out, target, ind, mask, cat):
    '''
    Arguments:
      out, target: B x C x H x W
      ind, mask: B x M
      cat (category id for peaks): B x M
    '''
    mask = mask.float()
    gt = torch.pow(1 - target, 4)
    # compute negative loss in heatmap
    neg_loss = torch.log(1 - out) * torch.pow(out, 2) * gt
    neg_loss = neg_loss.sum()

    pos_pred_pix = _transpose_and_gather_feat(out, ind) # B x M x C
    pos_pred = pos_pred_pix.gather(2, cat.unsqueeze(2)) # B x M x 1
    num_pos = mask.sum()

    # compute positive loss in heatmap
    pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) * \
               mask.unsqueeze(2)
    pos_loss = pos_loss.sum()
    if num_pos == 0:
      return - neg_loss
    return - (pos_loss + neg_loss) / num_pos



def neg_loss_cornernet(pred, gt, mask=None):
    """
    Refer to https://github.com/tianweiy/CenterPoint.
    Modified focal loss. Exactly the same as CornerNet. Runs faster and costs a little bit more memory
    Args:
        pred: (B x 8 x h x w)
        gt: (B x 8 x h x w)
        mask: (B x h x w)
    Returns:
    """
    pos_inds = gt.eq(1).float()  # 1 only at object centers (heatmap peaks)
    neg_inds = gt.lt(1).float()  # 1 everywhere else

    neg_weights = torch.pow(1 - gt, 4)  # [B,8,H,W]; down-weights negatives near a center

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds  # negatives near a center contribute little loss

    if mask is not None:
        mask = mask[:, None, :, :].float()
        pos_loss = pos_loss * mask
        neg_loss = neg_loss * mask
        num_pos = (pos_inds.float() * mask).sum()
    else:
        num_pos = pos_inds.float().sum()

    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos  # normalize the summed loss by the number of positives
    return loss


class FocalLossCenterNet(nn.Module):
    """
    Refer to https://github.com/tianweiy/CenterPoint
    """
    def __init__(self):
        super(FocalLossCenterNet, self).__init__()
        self.neg_loss = neg_loss_cornernet

    def forward(self, out, target, mask=None):
        return self.neg_loss(out, target, mask=mask)



class AssignLabel(object):
    def __init__(self, **kwargs):
        """Return CenterNet training labels like heatmap, height, offset"""

        self.tasks = kwargs["tasks"] #assigner_cfg.target_assigner.tasks

        assigner_cfg = kwargs["cfg"]

        self.out_size_factor = assigner_cfg.out_size_factor # 2
        self.gaussian_overlap = assigner_cfg.gaussian_overlap # 0.1
        self._max_objs = assigner_cfg.max_objs  # 500
        self._min_radius = assigner_cfg.min_radius # 2
        # tasks
        self.class_names = self.tasks["class_names"]  # a list of 8 class names
        self.num_classes = self.tasks["num_class"]  # 8

    def __call__(self, res,  grid_size , voxel_size , pc_range):
        max_objs = self._max_objs   # 500

        feature_map_size = grid_size[:2] // self.out_size_factor  # feature-map size [W, H] (x, y)

        draw_gaussian = draw_umich_gaussian
        # box layout: x, y, z, dx (l), dy (w), dz (h), yaw, class id
        gt_boxes = res['gt_boxes'].cpu().numpy()  # GT boxes from data_dict, [B, N, 8]
        batch_size = res['batch_size']

        # hm is heatmap
        hms, anno_boxs, inds, masks, cats = [], [], [], [], []

        # jinmu: compute the targets one sample at a time
        for batch_idx in range(batch_size):
            batch_box = gt_boxes[batch_idx,...]  # [N, 8]
            # N is the largest object count in the batch, so smaller samples are zero-padded;
            # a non-zero class id (last column) marks a real object
            batch_box_mask = batch_box[...,-1] != 0
            if np.all(batch_box_mask == False):
                batch_box_valid_num = 0
            else:  # e.g. mask = [1,1,1,1,0,0,0,0,0]; for a 1-D mask np.where returns only column indices
                batch_box_valid_num = np.where(batch_box_mask)[0].squeeze().max() + 1  # number of valid objects

            # c, h, w  [8, 200,380]
            hm = np.zeros((len(self.class_names), feature_map_size[1], feature_map_size[0]),
                            dtype=np.float32)
            # [500, 8]
            anno_box = np.zeros((max_objs, 8), dtype=np.float32)
            # [500]
            ind = np.zeros((max_objs), dtype=np.int64)
            mask = np.zeros((max_objs), dtype=np.uint8) # [500]
            cat = np.zeros((max_objs), dtype=np.int64)  # [500]

            # the box count must be the same for every frame so the batch can be
            # processed at once, but real frames have different numbers of boxes,
            # so a mask marks the valid slots
            num_objs = min(batch_box_valid_num, max_objs)  # number of objects in this frame

            for k in range(num_objs):
                cls_id = batch_box[k][-1] - 1  # 0-based class id
                l, w, h = batch_box[k][3], batch_box[k][4], batch_box[k][5]
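                # w and l in feature-map cells
                w, l = w / voxel_size[1] / self.out_size_factor, l / voxel_size[0] / self.out_size_factor
                if w > 0 and l > 0:
                    # Gaussian radius from l and w: for a minimum overlap between the two boxes, solve
                    # a quadratic in r for three cases (inside, enclosing, one corner in and one out)
                    radius = gaussian_radius((l, w), min_overlap=self.gaussian_overlap)  # l, w are floats; min_overlap is 0.1 here
                    radius = max(self._min_radius, int(radius))  # enforce the minimum radius (2)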

                    # center coordinates on the feature map
                    x, y, z = batch_box[k][0], batch_box[k][1], batch_box[k][2]
                    coor_x, coor_y = (x - pc_range[0]) / voxel_size[0] / self.out_size_factor, \
                                        (y - pc_range[1]) / voxel_size[1] / self.out_size_factor

                    ct = np.array([coor_x, coor_y], dtype=np.float32)
                    ct_int = ct.astype(np.int32)  # truncate to the integer cell index

                    # throw out not in range objects to avoid out of array area when creating the heatmap
                    # if beyond range, then continue
                    if not (0 <= ct_int[0] < feature_map_size[0] and 0 <= ct_int[1] < feature_map_size[1]):
                        continue 

                    # draw gaussian in heatmap gt
                    draw_gaussian(hm[int(cls_id)], ct, radius)  # draw on the heatmap channel of this class

                    new_idx = k  # the k-th object
                    x, y = ct_int[0], ct_int[1]

                    cat[new_idx] = cls_id  # class of the object
                    ind[new_idx] = y * feature_map_size[0] + x  # flattened index of the object on the feature map
                    mask[new_idx] = 1  # mark this slot as a valid object
                    rot = batch_box[k][6]
                    # fill regression target, ct - (x,y) is x_offset and y_offset
                    # rot is yaw angle
                    anno_box[new_idx] = np.concatenate(
                        (ct - (x, y), z, np.log(batch_box[k][3:6]),
                        np.sin(rot), np.cos(rot)), axis=None)  # [x_off, y_off, z, log l, log w, log h, sin, cos]

            hms.append(hm)
            anno_boxs.append(anno_box)
            masks.append(mask)
            inds.append(ind)
            cats.append(cat)

        hms = torch.from_numpy(np.stack(hms)).cuda()  # stack along a new batch dimension 0
        anno_boxs = torch.from_numpy(np.stack(anno_boxs)).cuda()
        inds = torch.from_numpy(np.stack(inds)).cuda()
        cats = torch.from_numpy(np.stack(cats)).cuda()
        masks = torch.from_numpy(np.stack(masks)).cuda()
        # hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500], cat [B,500]
        example = {'hm': hms, 'anno_box': anno_boxs, 'ind': inds, 'mask': masks, 'cat': cats}

        return example


class SepHead(nn.Module):
    """
    SepHead holds the actual heads: (heatmap) (x_offset, y_offset) (z) (dim) (sin(theta), cos(theta))
    """
    def __init__(
        self,
        in_channels,
        heads,
        head_conv=64,
        final_kernel=1,
        bn=False,
        init_bias=-2.19,
        **kwargs,
    ):
        super(SepHead, self).__init__(**kwargs)

        self.heads = heads # {'reg': [2, 2], 'height': [1, 2], 'dim': [3, 2], 'rot': [2, 2], 'hm': [8, 2]}
        for head in self.heads:  # iterate over the head names (dict keys)
            # classes: output channels of this head; num_conv: total number of conv layers
            classes, num_conv = self.heads[head]

            fc = Sequential()
            # layers number decided by config
            c_in = in_channels
            for i in range(num_conv-1):
                fc.add(nn.Conv2d(c_in, head_conv,
                    kernel_size=final_kernel, stride=1,
                    padding=final_kernel // 2, bias=True))
                if bn:
                    fc.add(nn.BatchNorm2d(head_conv))
                fc.add(nn.ReLU())
                c_in = head_conv  # later convs take head_conv channels as input

            # output conv
            fc.add(nn.Conv2d(c_in, classes,
                    kernel_size=final_kernel, stride=1,
                    padding=final_kernel // 2, bias=True))
            # the heatmap head gets a fixed bias init (init_bias); the other heads use Kaiming init
            if 'hm' in head:
                fc[-1].bias.data.fill_(init_bias)
            else:
                for m in fc.modules():
                    if isinstance(m, nn.Conv2d):
                        kaiming_init(m)
            # with num_conv = 2, each head is conv (+BN) + ReLU followed by an output conv
            # that maps to the required number of channels
            # register fc as a submodule named after the head, so it can be fetched with getattr
            self.__setattr__(head, fc)
        

    def forward(self, x):
        ret_dict = dict()        
        for head in self.heads:
            ret_dict[head] = self.__getattr__(head)(x)
        # ret_dict: reg [B,2,200,380], height [B,1,200,380], dim [B,3,200,380], rot [B,2,200,380], hm [B,8,200,380]
        return ret_dict


class CenterHead(nn.Module):
    def __init__(
        self,
        model_cfg,
        input_channels=[128,],
        num_class=1,
        class_names=None,
        grid_size=[0.32,0.32,0.16],
        point_cloud_range=None,
        predict_boxes_when_training=False,
        logger=None,
        init_bias=-2.19,
        num_hm_conv=2,
    ):
        super(CenterHead, self).__init__()
        assert(len(class_names) == num_class)
        
        tasks = dict(num_class=num_class, class_names=class_names)
        self.label_assigner = AssignLabel(cfg=model_cfg.TARGET_ASSIGNER_CONFIG, tasks=tasks)
        
        self.out_size_factor = model_cfg.TARGET_ASSIGNER_CONFIG.out_size_factor # 2
        self.model_cfg = model_cfg

        self.class_names = [class_names]  # class_names was a list; now it is [[a, b, c, ...]]
        self.num_classes = [num_class]  # [8]

        self.code_weights = model_cfg.code_weights #[5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0]
        self.weight = model_cfg.weight # 0.25 
        
        self.in_channels = input_channels # 384

        #self.crit = FastFocalLoss()
        self.crit = FocalLossCenterNet()
        self.crit_reg = RegLoss()

        

        common_heads = model_cfg.common_heads #{'reg': [ 2, 2 ],'height': [ 1, 2 ],'dim': [ 3, 2 ],'rot': [ 2, 2 ]}

        self.box_n_dim = 9 if 'vel' in common_heads else 7  # 7
        self.use_direction_classifier = False 

        if not logger:
            logger = logging.getLogger("CenterHead")
        self.logger = logger

        logger.info(
            f"num_classes: {self.num_classes}"
        )

        # a shared convolution 
        share_conv_channel = 64 if "share_conv_channel" not in model_cfg else model_cfg.share_conv_channel # 64
        self.shared_conv = nn.Sequential(
            nn.Conv2d(self.in_channels, share_conv_channel,
            kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(share_conv_channel),
            nn.ReLU(inplace=True)
        )

        self.tasks = nn.ModuleList()
        print("Use HM Bias: ", init_bias)

        for num_cls in self.num_classes:  # self.num_classes is [8], so this loops once
            heads = copy.deepcopy(common_heads) 
            heads.update(dict(hm=(num_cls, num_hm_conv))) #{'reg': [2, 2], 'height': [1, 2], 'dim': [3, 2], 'rot': [2, 2], 'hm': [8, 2]}
            self.tasks.append(
                SepHead(share_conv_channel, heads, bn=True, init_bias=init_bias, final_kernel=3)
            )

        self.frozen_param = model_cfg.FROZON_PARAM
        self.frozen_parameters()

        logger.info("Finish CenterHead Initialization")

    def forward(self, data_dict, *kwargs):

        x = data_dict['spatial_features_2d'] # [B, 384, 200, 380]
        x = self.shared_conv(x)  # first reduce the channels to 64
        ret_dicts = []

        for task in self.tasks:
            ret_dicts.append(task(x))
        # each element is a dict: reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W], hm [B,8,H,W]
        data_dict['centerhead_preds'] = ret_dicts

        return data_dict

    def _sigmoid(self, x):
        y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
        return y

    def loss(self, data_dict, **kwargs):
        # dict of targets generated from the GT: hm [B,8,H,W], anno_box [B,n,8], ind [B,n], mask [B,n], cat [B,n]
        example = self.label_assigner(data_dict, kwargs["grid_size"], kwargs["voxel_size"], kwargs["pc_range"])

        # get centerhead output reg[B,2,200,380] height[B,1,200,380] dim [B,3,200,380] rot [B,2,200,380] hm [B,8,200,380]
        preds_dicts = data_dict['centerhead_preds']

        assert(len(preds_dicts) == 1)
        # TODO refactor this
        preds_dict = preds_dicts[0]  # preds_dicts is a list with one dict per task; take the only one

        # apply sigmoid for heatmap output
        preds_dict['hm'] = self._sigmoid(preds_dict['hm'])  # sigmoid clamped to [1e-4, 1 - 1e-4] so log() in the focal loss stays finite
        # hm_loss = self.crit(
        #     preds_dict['hm'], 
        #     example['hm'], 
        #     example['ind'], 
        #     example['mask'], 
        #     example['cat']
        #     )
        
        hm_loss = self.crit(preds_dict['hm'], example['hm'])  # FocalLossCenterNet (CornerNet-style focal loss)

        target_box = example['anno_box']
        # velocity is not predicted in this setup, so the 'vel' branch is unused
        if 'vel' in preds_dict:
            preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
                                                preds_dict['vel'], preds_dict['rot']), dim=1)  
        else:
            preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
                                                preds_dict['rot']), dim=1)   

        # Regression loss for dimension, offset, height, rotation: a loss tensor of length 8
        box_loss = self.crit_reg(preds_dict['anno_box'], example['mask'], example['ind'], target_box)
        box_loss = box_loss * box_loss.new_tensor(self.code_weights)  # new_tensor keeps box_loss's dtype and device

        # box code layout: [x_off, y_off, z, log l, log w, log h, sin, cos]
        reg_loss = box_loss[:2]
        height_loss = box_loss[2]
        dim_loss = box_loss[3:6]
        rot_loss = box_loss[6:]
        
        loc_loss = box_loss.sum()
        loc_loss *= self.weight

        # total loss
        loss = hm_loss + loc_loss
        #ret = {'loss': loss, 'hm_loss': hm_loss, 'loc_loss':loc_loss, 'loc_loss_elem': box_loss.detach().cpu(), 'num_positive': example['mask'][0].float().sum()}
        # ret = {'hm_loss': hm_loss, 'loc_loss': loc_loss, 
        #         'reg_loss': reg_loss, 'height_loss': height_loss, 
        #         'dim_loss': dim_loss, 'rot_loss': rot_loss}

        ret = {'hm_loss': hm_loss, 'loc_loss': loc_loss}
        
        return ret
    
    def frozen_parameters(self):
        if self.frozen_param:
            for parameter in self.parameters():
                parameter.requires_grad = False

    @torch.no_grad()
    def predict(self, preds_dicts, test_cfg, **kwargs):
        """decode, nms, then return the detection result.
        """

        voxel_size = kwargs["voxel_size"]
        pc_range = kwargs["pc_range"]

        post_center_range = pc_range
        # dict: reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W], hm [B,8,H,W]
        preds_dicts = preds_dicts['centerhead_preds']

        if len(post_center_range) > 0:
            post_center_range = torch.tensor(
                post_center_range,
                dtype=preds_dicts[0]['hm'].dtype,
                device=preds_dicts[0]['hm'].device,
            )

        rets = []
        # jinmu: only one task is supported for now
        for task_id, preds_dict in enumerate(preds_dicts):
            # convert B C H W to B H W C 
            for key, val in preds_dict.items():
                preds_dict[key] = val.permute(0, 2, 3, 1).contiguous()

            batch_size = preds_dict['hm'].shape[0]
            batch_hm = torch.sigmoid(preds_dict['hm'])

            # exp for dim output to keep dim > 0
            batch_dim = torch.exp(preds_dict['dim'])  # dim is l, w, h (dx, dy, dz), matching the log-encoded targets

            # cos(theta) and sin(theta)
            batch_rots = preds_dict['rot'][..., 0:1]
            batch_rotc = preds_dict['rot'][..., 1:2]

            # x offset and y offset output
            batch_reg = preds_dict['reg']
            # z output
            batch_hei = preds_dict['height']

            # atan to recover true theta
            batch_rot = torch.atan2(batch_rots, batch_rotc)  # recover the angle from sin and cos

            batch, H, W, num_cls = batch_hm.size()

            # reshape for compute convenient
            batch_reg = batch_reg.reshape(batch, H*W, 2)
            batch_hei = batch_hei.reshape(batch, H*W, 1)

            batch_rot = batch_rot.reshape(batch, H*W, 1)
            batch_dim = batch_dim.reshape(batch, H*W, 3)
            batch_hm = batch_hm.reshape(batch, H*W, num_cls)  # flatten H and W for convenience

            # compute the x and y index of every grid cell; combined with x_offset and
            # y_offset below, these recover continuous lidar x and y
            ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
            ys = ys.view(1, H, W).repeat(batch, 1, 1).to(batch_hm.device).float()
            xs = xs.view(1, H, W).repeat(batch, 1, 1).to(batch_hm.device).float()

            # x y  + x_offset y_offset to recover continuous x y value
            xs = xs.view(batch, -1, 1) + batch_reg[:, :, 0:1]
            ys = ys.view(batch, -1, 1) + batch_reg[:, :, 1:2]

            xs = xs * self.out_size_factor * voxel_size[0] + pc_range[0]
            ys = ys * self.out_size_factor * voxel_size[1] + pc_range[1]

            # jinmu: ignored here, since there is no velocity output for now
            if 'vel' in preds_dict:
                batch_vel = preds_dict['vel']
                batch_vel = batch_vel.reshape(batch, H*W, 2)
                batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_vel, batch_rot], dim=2)
            else: 
                batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_rot], dim=2)

            if test_cfg.get('per_class_nms', False):
                pass 
            else:
                rets.append(self.post_processing(batch_box_preds, batch_hm, test_cfg, post_center_range)) 

        assert(len(rets) == 1) # only one task

        return rets[0]

    @torch.no_grad()
    def post_processing(self, batch_box_preds, batch_hm, test_cfg, post_center_range):
        batch_size = len(batch_hm)
        # batch_box_preds [B,H*W,7] batch_hm [B,H*W,8]
        prediction_dicts = []
        for i in range(batch_size):  # process samples one at a time
            box_preds = batch_box_preds[i]
            hm_preds = batch_hm[i]

            # score and label come from a max over the 8 class channels of the heatmap
            scores, labels = torch.max(hm_preds, dim=-1)  # max score and its index (the class), both [H*W]

            # score mask: keep scores above the per-class threshold
            #score_mask = scores > test_cfg.score_threshold 
            score_threshold = torch.tensor(test_cfg.score_threshold,
                                           dtype=scores.dtype, device=scores.device)[labels]  # threshold of each cell's class, [H*W]
            score_mask = scores > score_threshold  # cells above their class threshold count as detections

            # distance_mask: only keep boxes whose 3D center lies inside the given range
            # not use this in perception postprocess code
            distance_mask = (box_preds[..., :3] >= post_center_range[:3]).all(1) \
                & (box_preds[..., :3] <= post_center_range[3:]).all(1)

            # mask is intersection of two mask
            mask = distance_mask & score_mask 

            # get masked data
            box_preds = box_preds[mask]  # keep the boxes among the H*W candidates that pass both masks
            scores = scores[mask]
            labels = labels[mask]

            # get box for nms, each box in [x y z dx dy dz theta] format
            boxes_for_nms = box_preds[:, [0, 1, 2, 3, 4, 5, -1]]

            # bev rotated box nms
            selected = rotate_nms_pcdet(boxes_for_nms, scores, 
                                thresh=test_cfg.nms.nms_iou_threshold,
                                pre_maxsize=test_cfg.nms.nms_pre_max_size,
                                post_max_size=test_cfg.nms.nms_post_max_size)

            # selected is box mask after nms
            selected_boxes = box_preds[selected]
            selected_scores = scores[selected]
            selected_labels = labels[selected]

            # fill result, selected_boxes: n * 7, selected_scores: n * 1,
            # selected_labels: n * 1
            record_dict = {
                'pred_boxes': selected_boxes,
                'pred_scores': selected_scores,
                'pred_labels': selected_labels + 1
            }

            prediction_dicts.append(record_dict)

        return prediction_dicts 
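
rotate_nms_pcdet delegates the actual suppression to OpenPCDet's CUDA op iou3d_nms_cuda.nms_gpu, which uses rotated BEV IoU. The greedy procedure around it (sort by score, keep the top pre_maxsize, suppress boxes that overlap a kept box by more than the threshold, truncate to post_max_size) can be sketched in numpy as follows. As a simplifying assumption this sketch ignores yaw and uses axis-aligned BEV IoU, so it is not a drop-in replacement; bev_nms_axis_aligned is a hypothetical name.

```python
import numpy as np


def bev_nms_axis_aligned(boxes, scores, iou_thr, pre_max=None, post_max=None):
    """Greedy NMS on BEV footprints, ignoring yaw (axis-aligned simplification).

    boxes: [N, 7] as (x, y, z, l, w, h, yaw); scores: [N].
    Returns indices into boxes of the kept boxes, highest score first.
    """
    order = np.argsort(-scores)
    if pre_max is not None:
        order = order[:pre_max]                    # like pre_maxsize
    x1 = boxes[:, 0] - boxes[:, 3] / 2
    x2 = boxes[:, 0] + boxes[:, 3] / 2
    y1 = boxes[:, 1] - boxes[:, 4] / 2
    y2 = boxes[:, 1] + boxes[:, 4] / 2
    area = boxes[:, 3] * boxes[:, 4]
    keep = []
    while order.size > 0:
        i = order[0]                               # best remaining box
        keep.append(i)
        rest = order[1:]
        iw = np.clip(np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]), 0, None)
        ih = np.clip(np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]), 0, None)
        inter = iw * ih
        iou = inter / (area[i] + area[rest] - inter)
        order = rest[iou <= iou_thr]               # drop boxes overlapping the kept one too much
    keep = np.array(keep, dtype=np.int64)
    return keep if post_max is None else keep[:post_max]   # like post_max_size
```

For two nearly identical boxes and one distant box, the lower-scoring duplicate is suppressed and the distant box survives.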
