MMDetection学习笔记(五):整体构建流程与代码解析

这篇具有很好参考价值的文章主要介绍了MMDetection学习笔记(五):整体构建流程与代码解析。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

写在前面:建议先看完博主的另一篇博客核心组件分析,再去理解整个代码逻辑,结合代码反复阅读,抓住其中面向对象编程的核心思想,祝顺利,欢迎留言评论,博主会定期解答!

整体构建流程

按照数据流过程,训练流程可以简单总结为:

  1. 获取config配置并初始化各种类的实例化,通过Runner进行全生命周期管理:
    (1)Model类初始化,并根据是否多卡训练,进一步对Model类的上层进一步封装,若是分布式(单机多卡或多机多卡)训练,则初始化MMDistributedDataParallel类,若单机训练,则初始化MMDataParallel类;这两个类不仅可以处理 DataContainer 对象,还额外实现了 train_step() 和 val_step() 两个函数,可以被 Runner 调用。
    (2)Dataset类初始化,在迭代输出数据的时候需要通过数据 Pipeline 对数据进行各种处理,最典型的处理流是训练中的数据增强操作,测试中的数据预处理等等;将 Sampler(通过 Sampler 采样器可以控制 Dataset 输出的数据顺序,最常用的是随机采样器 RandomSampler。由于 Dataset 中输出的图片大小不一样,为了尽可能减少后续组成 batch 时 pad 的像素个数,MMDetection 引入了分组采样器 GroupSampler 和 DistributedGroupSampler,相当于在 RandomSampler 基础上额外新增了根据图片宽高比进行 group 功能)和 Dataset 都输入给 DataLoader,然后通过 DataLoader 输出已组成 batch 的数据,作为 Model 的输入;
    (3)Runner类初始化,它负责管理每一个epoch和iteration的train或val,还负责调用hook实现功能扩展,从而方便地获取、修改和拦截任何生命周期数据流。
    (4)Logger、Hook等类的初始化。
  2. Model 运行,输出 loss 以及其他一些信息,会通过 logger 进行保存或者可视化;
  3. 根据loss计算梯度并更新权重;

而测试流程就比较简单了,直接对 DataLoader 输出的数据进行前向推理即可,还原到最终原图尺度过程也是在 Model 中完成。

以上就是 MMDetection 框架整体训练和测试抽象流程,上图不仅仅反映了训练和测试数据流,而且还包括了模块和模块之间的调用关系。对于训练而言,最核心部分应该是 Runner,理解了 Runner 的运行流程,也就理解了整个 MMDetection 数据流。

代码解析

训练流程

1、初始化配置、logger、model、datasets、runner等,调用runner.run()函数

#=================== tools/train.py ==================
# 1.初始化配置
cfg = Config.fromfile(args.config)

# 2.判断是否为分布式训练模式

# 3.初始化 logger
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

# 4.收集运行环境并且打印,方便排查硬件和软件相关问题
env_info_dict = collect_env()

# 5.初始化 model
model = build_detector(cfg.model, ...)

# 6.初始化 datasets

#=================== mmdet/apis/train.py ==================
# 1.初始化 data_loaders ,内部会初始化 GroupSampler、DistributedSampler、DistributedGroupSampler
data_loader = DataLoader(dataset,...)

# 2.基于是否使用分布式训练,初始化对应的 DataParallel
if distributed:
  model = MMDistributedDataParallel(...)
else:
  model = MMDataParallel(...)

# 3.初始化 runner
runner = EpochBasedRunner(...)

# 4.注册必备 hook
runner.register_training_hooks(cfg.lr_config, optimizer_config,
                               cfg.checkpoint_config, cfg.log_config,
                               cfg.get('momentum_config', None))

# 5.如果需要 val,则还需要注册 EvalHook           
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

# 6.注册用户自定义 hook
runner.register_hook(hook, priority=priority)

# 7.权重恢复和加载
if cfg.resume_from:
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    runner.load_checkpoint(cfg.load_from)

# 8.运行,开始训练
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

runner 对象内部的 run 方式是一个通用方法,可以运行任何 workflow,目前常用的主要是 train 和 val。

  • 当配置为:workflow = [(‘train’, 1)],表示仅仅进行 train workflow,也就是迭代训练
  • 当配置为:workflow = [(‘train’, n),(‘val’, 1)],表示先进行 n 个 epoch 的训练,然后再进行1个 epoch 的验证,然后循环往复,如果写成 [(‘val’, 1),(‘train’, n)] 表示先进行验证,然后才开始训练

2、调用runner中的 train() 或者 val()
当进入对应的 workflow,则会调用 runner 里面的 train() 或者 val(),表示进行一次 epoch 迭代,如下所示:

def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    # 在每一次epoch训练前调用hook
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(self.data_loader):
    	# 在每一次iter训练前调用hook
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True)
        # 在每一次iter训练后调用hook
        self.call_hook('after_train_iter')
	# 在每一次epoch训练后调用hook
    self.call_hook('after_train_epoch')


def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    # 在每一次epoch验证前调用hook
    self.call_hook('before_val_epoch')
    for i, data_batch in enumerate(self.data_loader):
    	# 在每一次iter验证前调用hook
        self.call_hook('before_val_iter')
        with torch.no_grad():
            self.run_iter(data_batch, train_mode=False)
        # 在每一次iter验证后调用hook
        self.call_hook('after_val_iter')
   	# 在每一次epoch验证后调用hook
    self.call_hook('after_val_epoch')

在每一个epoch不断迭代循环,实现每一个iteration,核心函数实际上是 self.run_iter(),如下:

def run_iter(self, data_batch, train_mode, **kwargs):
    if train_mode:
        # 对于每次迭代,最终是调用如下函数
        outputs = self.model.train_step(data_batch,...)
    else:
        # 对于每次迭代,最终是调用如下函数
        outputs = self.model.val_step(data_batch,...)

    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'],...)
    self.outputs = outputs

3、runner 中调用 train_step 或者 val_step

#=================== mmcv/runner/epoch_based_runner.py ==================
if train_mode:
    outputs = self.model.train_step(data_batch,...)
else:
    outputs = self.model.val_step(data_batch,...)

实际上,首先会调用 DataParallel 中的 train_step 或者 val_step ,其具体调用流程为:

# 非分布式训练
#=================== mmcv/parallel/data_parallel.py/MMDataParallel ==================
def train_step(self, *inputs, **kwargs):
    if not self.device_ids:
        inputs, kwargs = self.scatter(inputs, kwargs, [-1])
        # 此时才是调用 model 本身的 train_step
        return self.module.train_step(*inputs, **kwargs)
    # 单 gpu 模式
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # 此时才是调用 model 本身的 train_step
    return self.module.train_step(*inputs[0], **kwargs[0])

# val_step 也是的一样逻辑
def val_step(self, *inputs, **kwargs):
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # 此时才是调用 model 本身的 val_step
    return self.module.val_step(*inputs[0], **kwargs[0])

可以发现,在调用 model 本身的 train_step 前,需要额外调用 scatter 函数,该函数的作用是处理 DataContainer 格式数据,使其能够组成 batch,否则程序会报错。

如果是分布式训练,则调用的实际上是 mmcv/parallel/distributed.py/MMDistributedDataParallel,最终调用的依然是 model 本身的 train_step 或者 val_step。

4、调用 model 中的 train_step 或者 val_step

#=================== mmdet/models/detectors/base.py/BaseDetector ==================
def train_step(self, data, optimizer):
    # 实例():调用__call__()函数,在函数内部会调用本类自身的 forward 方法
    losses = self(**data)
    # 解析 loss
    loss, log_vars = self._parse_losses(losses)
    # 返回字典对象
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
    return outputs

def forward(self, img, img_metas, return_loss=True, **kwargs):
    if return_loss:
        # 训练模式
        return self.forward_train(img, img_metas, **kwargs)
    else:
        # 测试模式
        return self.forward_test(img, img_metas, **kwargs)

forward_train 和 forward_test 需要在不同的算法子类中实现,输出是 Loss 或者 预测结果。

5、调用子类中的 forward_train 方法

目前提供了两个具体子类,TwoStageDetector 和 SingleStageDetector ,用于实现 two-stage 和 single-stage 算法。

对于 TwoStageDetector 而言,其核心逻辑是:

#============= mmdet/models/detectors/two_stage.py/TwoStageDetector ============
def forward_train(...):
    # 先进行 backbone+neck 的特征提取
    x = self.extract_feat(img)
    losses = dict()
    # RPN forward and loss
    if self.with_rpn:
        # 训练 RPN
        proposal_cfg = self.train_cfg.get('rpn_proposal',
                                          self.test_cfg.rpn)
        # 主要是调用 rpn_head 内部的 forward_train 方法
        rpn_losses, proposal_list = self.rpn_head.forward_train(x,...)
        losses.update(rpn_losses)
    else:
        proposal_list = proposals
    # 第二阶段,主要是调用 roi_head 内部的 forward_train 方法
    roi_losses = self.roi_head.forward_train(x, ...)
    losses.update(roi_losses)
    return losses

对于 SingleStageDetector 而言,其核心逻辑是:

#============= mmdet/models/detectors/single_stage.py/SingleStageDetector ============
def forward_train(...):
    super(SingleStageDetector, self).forward_train(img, img_metas)
    # 先进行 backbone+neck 的特征提取
    x = self.extract_feat(img)
    # 主要是调用 bbox_head 内部的 forward_train 方法
    losses = self.bbox_head.forward_train(x, ...)
    return losses

如果在TwoStageDetector 和 SingleStageDetector基础上封装了其他类,就会调用新类中的forward_train函数。

测试流程

对于测试逻辑由于比较简单,就不详细描述了,简单来说测试流程下不需要 runner,直接加载训练好的权重,然后进行 model 推理即可,下面简要概述:文章来源地址https://www.toymoban.com/news/detail-544444.html

  1. 调用 MMDataParallel 或 MMDistributedDataParallel 中的 forward 方法;
  2. 调用 base.py 中的 forward 方法;
  3. 调用 base.py 中的 self.forward_test 方法;
  4. 如果是单尺度测试,则会调用 TwoStageDetector 或 SingleStageDetector 中的 simple_test 方法,如果是多尺度测试,则调用 aug_test 方法;
  5. 最终调用的是每个具体算法 Head 模块的 simple_test 或者 aug_test 方法。

到了这里,关于MMDetection学习笔记(五):整体构建流程与代码解析的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • MMDetection学习笔记(四):核心组件分析

    此篇博客注重分析了MMDetection中三大核心组件:Registry、Hook和Runner。 Registry 机制其实维护的是一个全局字典,实现字符串到类的映射。通过 Registry 类,用户可以通过config中字符串的方式实例化任何想要的类(或模块)。Registry的优点在于:解耦性强、可扩展性强,代码更易理

    2024年02月12日
    浏览(44)
  • AI实战营第二期 第六节 《MMDetection代码课》——笔记7

    MMDetection 是被广泛使用的检测工具箱,包括了目标检侧、实例分割、全景分割等多个通用检测方向,并支持了 75+ 个主流和前沿模型, 为用户提供超过 440+ 个预训练模型, 在学术研究和工业落地中拥有广泛应用。该恇架的主要特点为: 模块化设计。MMDetection 将检测框架解耦成不

    2024年02月08日
    浏览(53)
  • [BEV] 学习笔记之BEVDet(原理+代码解析)

    前言 基于LSS的成功,鉴智机器人提出了BEVDet,目前来到了2.0版本,在nuscences排行榜中以mAP=0.586暂列第一名。本文将对BEVDet的原理进行简要说明,然后结合代码对BEVDet进深度解析。 repo: https://github.com/HuangJunJie2017/BEVDet paper:https://arxiv.org/abs/2211.17111 欢迎进入BEV感知交流群,一起

    2024年02月05日
    浏览(48)
  • 数学建模学习(3):综合评价类问题整体解析及分析步骤

    对物体进行评价,用具体的分值评价它们的优劣 选这两人其中之一当男朋友,你会选谁? 不同维度的权重会产生不同的结果 所以找到 每个维度的权重是最核心的问题 0.25 供应商 ID 可靠性 指标 2 指标 3 指标 4 指标 5 1 1 4 100 56 1000 2 2 6 105 55 2000 正向指标处理:即越大越好的指标

    2024年02月16日
    浏览(52)
  • 【现代机器人学】学习笔记十三:配套代码解析

    最近一直忙于工作,每天都在写一些业务代码。而目前工程中的技术栈并没有使用旋量这一套机器人理论系统,因此时间长了自己都忘记了。 于是决定把这本书配套的代码内容也过一遍,查漏补缺,把这本书的笔记内容完结一下。 代码来源于github:https://github.com/NxRLab/Moder

    2024年02月12日
    浏览(45)
  • Kimball维度模型之构建数据仓库流程解析

        目录 一 数据建模概述 二 构建数据仓库项目应该设计哪些模型表? 三 数据仓库项目的模型表应该如何设计? 三 总结      在开始学习之前请先思考两个问题?在你的脑海里对这两个问题是有已经有了清晰的答案? 构建数据仓库项目应该设计哪些模型表? 数据仓库项

    2024年03月22日
    浏览(50)
  • Android车载学习笔记1——车载整体系统简介

             汽车操作系统包括安全车载操作系统、智能驾驶操作系统和智能座舱操作系统。 1. 安全车载操作系统         安全车载操作系统主要面向经典车辆控制领域,如动力系统、底盘系统和车身系统等,该类操作系统对实时性和安全性要求极高,生态发展已趋于成

    2024年02月06日
    浏览(53)
  • 【深入Scrapy实战】从登录到数据解析构建完整爬虫流程

    【作者主页】: 吴秋霖 【作者介绍】:Python领域优质创作者、阿里云博客专家、华为云享专家。长期致力于Python与爬虫领域研究与开发工作! 【作者推荐】:对JS逆向感兴趣的朋友可以关注《爬虫JS逆向实战》,对分布式爬虫平台感兴趣的朋友可以关注《分布式爬虫平台搭建

    2024年02月04日
    浏览(51)
  • PSCAD学习笔记(2)——python调用PSCAD自动化库代码解析:组件控制

    该学习笔记结合官方文件和个人学习见解撰写,主要分享一些常见实用功能,欢迎讨论、补充、指正。PSCAD相关免费学习资源实属稀缺,如果本文对您有所帮助,麻烦点赞评论支持一下。您的支持是我更新的动力。 PSCAD版本:4.6.3 python版本:3.7 mhrc-automation版本:1.2.4 python编辑

    2024年02月22日
    浏览(88)
  • 基于LIDC-IDRI肺结节肺癌数据集的人工智能深度学习分类良性和恶性肺癌(Python 全代码)全流程解析(二)

    第一部分内容的传送门 环境配置建议使用anaconda进行配置。核心的配置是keras和tensorflow的版本要匹配。 环境配置如下: tensorboard 1.13.1 tensorflow 1.13.1 Keras 2.2.4 numpy 1.21.5 opencv-python 4.6.0.66 python 3.7 数据集的预处理分为两个关键步骤。首先是图片处理,我们使用cv2库将图片转换为

    2024年04月29日
    浏览(39)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包