写在前面:建议先看完博主的另一篇博客核心组件分析,再去理解整个代码逻辑,结合代码反复阅读,抓住其中面向对象编程的核心思想,祝顺利,欢迎留言评论,博主会定期解答!
整体构建流程
按照数据流过程,训练流程可以简单总结为:
- 获取config配置并初始化各种类的实例化,通过Runner进行全生命周期管理:
(1)Model类初始化,并根据是否多卡训练,进一步对Model类的上层进一步封装,若是分布式(单机多卡或多机多卡)训练,则初始化MMDistributedDataParallel类,若单机训练,则初始化MMDataParallel类;这两个类不仅可以处理 DataContainer 对象,还额外实现了 train_step() 和 val_step() 两个函数,可以被 Runner 调用。
(2)Dataset类初始化,在迭代输出数据的时候需要通过数据 Pipeline 对数据进行各种处理,最典型的处理流是训练中的数据增强操作,测试中的数据预处理等等;将 Sampler(通过 Sampler 采样器可以控制 Dataset 输出的数据顺序,最常用的是随机采样器 RandomSampler。由于 Dataset 中输出的图片大小不一样,为了尽可能减少后续组成 batch 时 pad 的像素个数,MMDetection 引入了分组采样器 GroupSampler 和 DistributedGroupSampler,相当于在 RandomSampler 基础上额外新增了根据图片宽高比进行 group 功能)和 Dataset 都输入给 DataLoader,然后通过 DataLoader 输出已组成 batch 的数据,作为 Model 的输入;
(3)Runner类初始化,它负责管理每一个epoch和iteration的train或val,还负责调用hook实现功能扩展,从而方便地获取、修改和拦截任何生命周期数据流。
(4)Logger、Hook等类的初始化。 - Model 运行,输出 loss 以及其他一些信息,会通过 logger 进行保存或者可视化;
- 根据loss计算梯度并更新权重;
而测试流程就比较简单了,直接对 DataLoader 输出的数据进行前向推理即可,还原到最终原图尺度过程也是在 Model 中完成。
以上就是 MMDetection 框架整体训练和测试抽象流程,上图不仅仅反映了训练和测试数据流,而且还包括了模块和模块之间的调用关系。对于训练而言,最核心部分应该是 Runner,理解了 Runner 的运行流程,也就理解了整个 MMDetection 数据流。
代码解析
训练流程
1、初始化配置、logger、model、datasets、runner等,调用runner.run()函数
#=================== tools/train.py ==================
# 1.初始化配置
cfg = Config.fromfile(args.config)
# 2.判断是否为分布式训练模式
# 3.初始化 logger
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# 4.收集运行环境并且打印,方便排查硬件和软件相关问题
env_info_dict = collect_env()
# 5.初始化 model
model = build_detector(cfg.model, ...)
# 6.初始化 datasets
#=================== mmdet/apis/train.py ==================
# 1.初始化 data_loaders ,内部会初始化 GroupSampler、DistributedSampler、DistributedGroupSampler
data_loader = DataLoader(dataset,...)
# 2.基于是否使用分布式训练,初始化对应的 DataParallel
if distributed:
model = MMDistributedDataParallel(...)
else:
model = MMDataParallel(...)
# 3.初始化 runner
runner = EpochBasedRunner(...)
# 4.注册必备 hook
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
# 5.如果需要 val,则还需要注册 EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
# 6.注册用户自定义 hook
runner.register_hook(hook, priority=priority)
# 7.权重恢复和加载
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
# 8.运行,开始训练
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
runner 对象内部的 run 方式是一个通用方法,可以运行任何 workflow,目前常用的主要是 train 和 val。
- 当配置为:workflow = [(‘train’, 1)],表示仅仅进行 train workflow,也就是迭代训练
- 当配置为:workflow = [(‘train’, n),(‘val’, 1)],表示先进行 n 个 epoch 的训练,然后再进行1个 epoch 的验证,然后循环往复,如果写成 [(‘val’, 1),(‘train’, n)] 表示先进行验证,然后才开始训练
2、调用runner中的 train() 或者 val()
当进入对应的 workflow,则会调用 runner 里面的 train() 或者 val(),表示进行一次 epoch 迭代,如下所示:
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
# 在每一次epoch训练前调用hook
self.call_hook('before_train_epoch')
for i, data_batch in enumerate(self.data_loader):
# 在每一次iter训练前调用hook
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True)
# 在每一次iter训练后调用hook
self.call_hook('after_train_iter')
# 在每一次epoch训练后调用hook
self.call_hook('after_train_epoch')
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
# 在每一次epoch验证前调用hook
self.call_hook('before_val_epoch')
for i, data_batch in enumerate(self.data_loader):
# 在每一次iter验证前调用hook
self.call_hook('before_val_iter')
with torch.no_grad():
self.run_iter(data_batch, train_mode=False)
# 在每一次iter验证后调用hook
self.call_hook('after_val_iter')
# 在每一次epoch验证后调用hook
self.call_hook('after_val_epoch')
在每一个epoch不断迭代循环,实现每一个iteration,核心函数实际上是 self.run_iter(),如下:
def run_iter(self, data_batch, train_mode, **kwargs):
if train_mode:
# 对于每次迭代,最终是调用如下函数
outputs = self.model.train_step(data_batch,...)
else:
# 对于每次迭代,最终是调用如下函数
outputs = self.model.val_step(data_batch,...)
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],...)
self.outputs = outputs
3、runner 中调用 train_step 或者 val_step
#=================== mmcv/runner/epoch_based_runner.py ==================
if train_mode:
outputs = self.model.train_step(data_batch,...)
else:
outputs = self.model.val_step(data_batch,...)
实际上,首先会调用 DataParallel 中的 train_step 或者 val_step ,其具体调用流程为:
# 非分布式训练
#=================== mmcv/parallel/data_parallel.py/MMDataParallel ==================
def train_step(self, *inputs, **kwargs):
if not self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
# 此时才是调用 model 本身的 train_step
return self.module.train_step(*inputs, **kwargs)
# 单 gpu 模式
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
# 此时才是调用 model 本身的 train_step
return self.module.train_step(*inputs[0], **kwargs[0])
# val_step 也是的一样逻辑
def val_step(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
# 此时才是调用 model 本身的 val_step
return self.module.val_step(*inputs[0], **kwargs[0])
可以发现,在调用 model 本身的 train_step 前,需要额外调用 scatter 函数,该函数的作用是处理 DataContainer 格式数据,使其能够组成 batch,否则程序会报错。
如果是分布式训练,则调用的实际上是 mmcv/parallel/distributed.py/MMDistributedDataParallel,最终调用的依然是 model 本身的 train_step 或者 val_step。
4、调用 model 中的 train_step 或者 val_step
#=================== mmdet/models/detectors/base.py/BaseDetector ==================
def train_step(self, data, optimizer):
# 实例():调用__call__()函数,在函数内部会调用本类自身的 forward 方法
losses = self(**data)
# 解析 loss
loss, log_vars = self._parse_losses(losses)
# 返回字典对象
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
return outputs
def forward(self, img, img_metas, return_loss=True, **kwargs):
if return_loss:
# 训练模式
return self.forward_train(img, img_metas, **kwargs)
else:
# 测试模式
return self.forward_test(img, img_metas, **kwargs)
forward_train 和 forward_test 需要在不同的算法子类中实现,输出是 Loss 或者 预测结果。
5、调用子类中的 forward_train 方法
目前提供了两个具体子类,TwoStageDetector 和 SingleStageDetector ,用于实现 two-stage 和 single-stage 算法。
对于 TwoStageDetector 而言,其核心逻辑是:
#============= mmdet/models/detectors/two_stage.py/TwoStageDetector ============
def forward_train(...):
# 先进行 backbone+neck 的特征提取
x = self.extract_feat(img)
losses = dict()
# RPN forward and loss
if self.with_rpn:
# 训练 RPN
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
# 主要是调用 rpn_head 内部的 forward_train 方法
rpn_losses, proposal_list = self.rpn_head.forward_train(x,...)
losses.update(rpn_losses)
else:
proposal_list = proposals
# 第二阶段,主要是调用 roi_head 内部的 forward_train 方法
roi_losses = self.roi_head.forward_train(x, ...)
losses.update(roi_losses)
return losses
对于 SingleStageDetector 而言,其核心逻辑是:
#============= mmdet/models/detectors/single_stage.py/SingleStageDetector ============
def forward_train(...):
super(SingleStageDetector, self).forward_train(img, img_metas)
# 先进行 backbone+neck 的特征提取
x = self.extract_feat(img)
# 主要是调用 bbox_head 内部的 forward_train 方法
losses = self.bbox_head.forward_train(x, ...)
return losses
如果在TwoStageDetector 和 SingleStageDetector基础上封装了其他类,就会调用新类中的forward_train函数。文章来源:https://www.toymoban.com/news/detail-544444.html
测试流程
对于测试逻辑由于比较简单,就不详细描述了,简单来说测试流程下不需要 runner,直接加载训练好的权重,然后进行 model 推理即可,下面简要概述:文章来源地址https://www.toymoban.com/news/detail-544444.html
- 调用 MMDataParallel 或 MMDistributedDataParallel 中的 forward 方法;
- 调用 base.py 中的 forward 方法;
- 调用 base.py 中的 self.forward_test 方法;
- 如果是单尺度测试,则会调用 TwoStageDetector 或 SingleStageDetector 中的 simple_test 方法,如果是多尺度测试,则调用 aug_test 方法;
- 最终调用的是每个具体算法 Head 模块的 simple_test 或者 aug_test 方法。
到了这里,关于MMDetection学习笔记(五):整体构建流程与代码解析的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!