Openpcdet 的POST_PROCESSING模块
在OpenPCDet中,POST_PROCESSING模块是用于在模型输出的点云检测结果上进行后处理的组件。
该模块主要负责对检测结果进行滤波、聚类、追踪等操作,以提高检测的准确性和稳定性。
POST_PROCESSING模块通常包含以下几个主要的子模块或步骤:
-
点云滤波(Point Cloud Filtering):这一步骤用于去除原始点云中的噪声和离群点,常用的滤波方法包括体素下采样(Voxel Downsampling)、统计滤波(Statistical Outlier Removal)等。
-
检测框聚类(Box Clustering):在一些场景中,模型可能会输出多个相似的检测框,这些框可能对应着同一个物体。通过聚类算法,可以将这些相似的框归为一类,从而得到更准确的检测结果。
-
对象追踪(Object Tracking):在连续帧的点云数据中,通过追踪算法可以将同一个物体在不同帧之间进行关联,从而实现物体的连续跟踪。常用的追踪算法包括卡尔曼滤波(Kalman Filtering)、匈牙利算法(Hungarian Algorithm)等。
-
检测结果过滤(Detection Result Filtering):根据应用需求,可以对最终的检测结果进行进一步过滤,例如根据置信度阈值进行筛选,去除不满足要求的检测结果。
Pointpillar POST_PROCESSING 配置文件
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False
EVAL_METRIC: kitti
NMS_CONFIG:
MULTI_CLASSES_NMS: False
NMS_TYPE: nms_gpu
NMS_THRESH: 0.01
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 500
-
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]:分别代表 [‘Car’, ‘Pedestrian’, ‘Cyclist’] 这是一个召回阈值列表,用于评估检测结果的召回率。在评估过程中,将计算不同召回阈值下的召回率,并输出相应的指标。
-
SCORE_THRESH: 0.1:这是一个得分阈值,用于过滤检测结果。只有得分高于该阈值的检测结果才会被保留,低于阈值的结果将被丢弃。
-
OUTPUT_RAW_SCORE: False:这是一个布尔值参数,用于指定是否在输出结果中包含原始的检测得分。如果设置为True,则在输出结果中将包含原始得分;如果设置为False,则只输出二值化的检测结果。
-
EVAL_METRIC: kitti:这是一个评估指标的选择,用于衡量检测结果的性能。在这种情况下,选择了Kitti评估指标,该指标通常用于衡量目标检测在Kitti数据集上的性能。
-
NMS_CONFIG:这是一个配置NMS(非极大值抑制)的子模块,用于在检测结果中进行框的合并和过滤。
-
MULTI_CLASSES_NMS: False:这是一个布尔值参数,用于指定是否对多类别进行NMS。如果设置为True,则会对多个类别的检测框进行NMS;如果设置为False,则只对同一类别的检测框进行NMS。
-
NMS_TYPE: nms_gpu:这是NMS算法的选择。在这种情况下,选择了nms_gpu算法,该算法使用GPU加速执行NMS操作。
-
NMS_THRESH: 0.01:这是NMS的阈值,用于控制重叠度的判定。当两个框的重叠度高于该阈值时,较低得分的框将被抑制。
-
NMS_PRE_MAXSIZE: 4096:这是NMS操作之前,每个类别最大保留的检测框数量。如果超过该数量,将根据得分进行排序并截断。
-
NMS_POST_MAXSIZE: 500:这是NMS操作之后,每个类别最大保留的检测框数量。如果超过该数量,将根据得分进行排序并截断。
POST_PROCESSING 代码讲解
代码在OpenPCDet/pcdet/models/detectors/detector3d_template.py下面
def post_processing(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
or [(B, num_boxes, num_class1), (B, num_boxes, num_class2) ...]
multihead_label_mapping: [(num_class1), (num_class2), ...]
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
cls_preds_normalized: indicate whether batch_cls_preds is normalized
batch_index: optional (N1+N2+...)
has_class_labels: True/False
roi_labels: (B, num_rois) 1 .. num_classes
batch_pred_labels: (B, num_boxes, 1)
Returns:
pred_dicts: 一个包含预测结果的列表,每个元素是一个字典,包含了预测框的坐标、得分和类别
recall_dict: 一个包含召回率信息的字典,用于评估检测结果的召回率
"""
post_process_cfg = self.model_cfg.POST_PROCESSING
batch_size = batch_dict['batch_size']
recall_dict = {}
pred_dicts = []
for index in range(batch_size):
# 根据是否包含batch_index来确定box_preds的形状
if batch_dict.get('batch_index', None) is not None:
assert batch_dict['batch_box_preds'].shape.__len__() == 2
batch_mask = (batch_dict['batch_index'] == index)
else:
assert batch_dict['batch_box_preds'].shape.__len__() == 3
batch_mask = index
box_preds = batch_dict['batch_box_preds'][batch_mask]
src_box_preds = box_preds
# 处理分类预测结果
if not isinstance(batch_dict['batch_cls_preds'], list):
cls_preds = batch_dict['batch_cls_preds'][batch_mask]
src_cls_preds = cls_preds
assert cls_preds.shape[1] in [1, self.num_class]
if not batch_dict['cls_preds_normalized']:
cls_preds = torch.sigmoid(cls_preds)
else:
cls_preds = [x[batch_mask] for x in batch_dict['batch_cls_preds']]
src_cls_preds = cls_preds
if not batch_dict['cls_preds_normalized']:
cls_preds = [torch.sigmoid(x) for x in cls_preds]
# 多类别NMS
if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS:
if not isinstance(cls_preds, list):
cls_preds = [cls_preds]
multihead_label_mapping = [torch.arange(1, self.num_class, device=cls_preds[0].device)]
else:
multihead_label_mapping = batch_dict['multihead_label_mapping']
cur_start_idx = 0
pred_scores, pred_labels, pred_boxes = [], [], []
for cur_cls_preds, cur_label_mapping in zip(cls_preds, multihead_label_mapping):
assert cur_cls_preds.shape[1] == len(cur_label_mapping)
cur_box_preds = box_preds[cur_start_idx: cur_start_idx + cur_cls_preds.shape[0]]
cur_pred_scores, cur_pred_labels, cur_pred_boxes = model_nms_utils.multi_classes_nms(
cls_scores=cur_cls_preds, box_preds=cur_box_preds,
nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=post_process_cfg.SCORE_THRESH
)
cur_pred_labels = cur_label_mapping[cur_pred_labels]
pred_scores.append(cur_pred_scores)
pred_labels.append(cur_pred_labels)
pred_boxes.append(cur_pred_boxes)
cur_start_idx += cur_cls_preds.shape[0]
final_scores = torch.cat(pred_scores, dim=0)
final_labels = torch.cat(pred_labels, dim=0)
final_boxes = torch.cat(pred_boxes, dim=0)
else:
# 单类别NMS
cls_preds, label_preds = torch.max(cls_preds, dim=-1)
if batch_dict.get('has_class_labels', False):
label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels'
label_preds = batch_dict[label_key][index]
else:
label_preds = label_preds + 1
selected, selected_scores = model_nms_utils.class_agnostic_nms(
box_scores=cls_preds, box_preds=box_preds,
nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=post_process_cfg.SCORE_THRESH
)
if post_process_cfg.OUTPUT_RAW_SCORE:
max_cls_preds, _ = torch.max(src_cls_preds, dim=-1)
selected_scores = max_cls_preds[selected]
final_scores = selected_scores
final_labels = label_preds[selected]
final_boxes = box_preds[selected]
# 生成召回率记录
recall_dict = self.generate_recall_record(
box_preds=final_boxes if 'rois' not in batch_dict else src_box_preds,
recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
thresh_list=post_process_cfg.RECALL_THRESH_LIST
)
# 构建预测结果字典,并添加到预测结果列表中
record_dict = {
'pred_boxes': final_boxes,
'pred_scores': final_scores,
'pred_labels': final_labels
}
pred_dicts.append(record_dict)
return pred_dicts, recall_dict
代码里面的一些小细节,后续有人问到再讲文章来源:https://www.toymoban.com/news/detail-778191.html
引用
Openpcdet 系列 Pointpillar代码逐行解析
OpenPCDet 环境安装
OpenPCDet KITTI数据加载过程 (Pointpillar模型)
Openpcdet 系列 Pointpillar代码逐行解析之Voxel Feature Encoding (VFE)模块
Openpcdet 系列 Pointpillar代码逐行解析之MAP_TO_BEV模块
Openpcdet 系列 Pointpillar代码逐行解析之BACKBONE_2D模块
Openpcdet 系列 Pointpillar代码逐行解析之检测头(DENSE_HEAD)模块
Openpcdet 系列 Pointpillar代码逐行解析之POST_PROCESSING模块
文章来源地址https://www.toymoban.com/news/detail-778191.html
到了这里,关于Openpcdet 系列 Pointpillar代码逐行解析之POST_PROCESSING模块的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!