核心组件分析
此篇博客注重分析了MMDetection中三大核心组件:Registry、Hook和Runner。
Registry
Registry 机制其实维护的是一个全局字典,实现字符串到类的映射。通过 Registry 类,用户可以通过config中字符串的方式实例化任何想要的类(或模块)。Registry的优点在于:解耦性强、可扩展性强,代码更易理解。
MMCV中Registry类的实现源码:
class Registry:
def __init__(self, name):
# 可实现注册类细分功能
self._name = name
# 内部核心内容,维护所有的已经注册好的 class
self._module_dict = dict()
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
# 最核心代码
self._module_dict[module_name] = module_class
# 装饰器函数
def register_module(self, name=None, force=False, module=None):
if module is not None:
# 如果已经是 module,那就知道 增加到字典中即可
self._register_module(
module_class=module, module_name=name, force=force)
return module
# 最标准用法
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
在 MMCV 中所有的类实例化都是通过build_from_cfg
函数实现,做的事情非常简单,就是给定module_name
,然后从 self._module_dict
提取即可。
def build_from_cfg(cfg, registry, default_args=None):
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type') # 注册 str 类名
if is_str(obj_type):
# 相当于 self._module_dict[obj_type]
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
# 如果已经实例化了,那就直接返回
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
# 最终初始化对于类,并且返回,就完成了一个类的实例化过程
return obj_cls(**args)
一个完整的使用例子如下:
# registry
CONVERTERS = Registry('converter')
@CONVERTERS.register_module()
class Converter1(object):
def __init__(self, a, b):
self.a = a
self.b = b
# config
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_from_cfg(converter_cfg,CONVERTERS)
Hook
Hook的定义
在 wiki 百科中定义如下:
钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)
简单来说,Hook机制可以在代码运行的整个生命周期中无侵入地拓展功能。Hook 机制在 OpenMMLab 系列框架中应用非常广泛,结合 Runner 类可以实现对训练过程的整个生命周期进行管理。同时内置了多种 Hook,通过注册的形式注入 Runner 中实现了丰富的扩展功能,例如模型权重保存、日志记录、lr超参数的调整等等。
Hook的调用机制
在MMDetection中,Hook 是可以注册进 Runner 中,不同类型的 Hook 实现了不同的生命周期方法从而完成不同的功能,以一个典型的训练过程为例,EpochBasedRunner(以 epoch 为单位) 中生命周期方法如下所示:
# 开始运行时调用
before_run()
while self.epoch < self._max_epochs:
# 开始 epoch 迭代前调用
before_train_epoch()
for i, data_batch in enumerate(self.data_loader):
# 开始一次(iteration)迭代前调用
before_train_iter()
self.model.train_step()
# 经过一次(iteration)迭代后调用
after_train_iter()
# 经过一个 epoch 迭代后调用
after_train_epoch()
# 运行完成前调用
after_run()
只要注册的 Hook 对象实现了某一个或者某几个生命周期方法,当 Runner 运行到预定义的位点时候就会调用对应的 Hook 中方法。
Hook的分类与用法
MMCV中实现的Hook有默认Hook和定制Hook,默认 Hook不需要用户自行注册,用户通过 (hook 名)_config 配置对应参数即可;而对于定制 Hook,则需要用户手动注册或者通过配置方式注册进去。
对于默认 Hook,在 MMDetection 框架训练过程中,其注册代码为:
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
register_training_hooks
函数的接收参数其实是字典参数,Runner 内部会根据配置自动生成对应的 Hook 实例,典型的 lr_config 为:
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[16, 22])
对于定制类 Hook,其注册源码如下:
# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
for hook_cfg in cfg.custom_hooks:
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
# 通过配置实例化定制 hook
hook = build_from_cfg(hook_cfg, HOOKS)
# 注册
runner.register_hook(hook, priority=priority)
以 EMAHook 为例,其 .py 配置文件应该写成:
custom_hooks=[dict(type='EMAHook')]
下面对一些比较通用的、常用的 Hook 进行功能简析:
- CheckpointHook
CheckpointHook 主要是对模型参数进行保存,如果是分布式多卡训练,则仅仅会在 master 进程保存。同时可以通过max_keep_ckpts
参数设置最多保存多少个权重文件,早期额外的权重会自动删除。
如果以 epoch 为单位进行保存,则该 Hook 实现after_train_epoch
方法即可,否则仅实现after_train_iter
方法即可。 - LrUpdaterHook
LrUpdaterHook 用于学习率调度,为了统一代码风格以及方便扩展,MMDetection 等训练框架并没有直接继承 PyTorch 提供的学习率调度器,而是通过 LrUpdaterHook 实现。
如果是以 iter 为单位,则仅仅需要在before_train_iter
方法中实现学习率调度功能,如果是以 epoch 为单位,则还需要在before_train_epoch
中实现相关操作。简单来说要实时改变学习率。 - OptimizerHook
OptimizerHook 功能比较简单:梯度反向传播加上参数更新,如果指定了梯度裁剪参数,则可以进行梯度裁剪。 - ClosureHook
ClosureHook 比较特殊,他的主要功能是提供最简洁的函数注册。
可以想象一个场景:在训练过程中,想知道目前的迭代次数,在目前框架体系下最优雅的实现方式是:用户自己写一个获取 iter 的 Hook 类,然后在配置文件中通过custom_hooks
注册进去,该类的代码如下所示:
可以发现你需要做如下事情:@HOOKS.register_module() class GetIterHook(Hook): def after_train_iter(self, runner): print(runner.iter)
(1)写一个 GetIterHook 类,继承自 Hook;
(2)在类上方加上 @HOOKS.register_module();
(3)在对应的 init.py 文件中进行导入;
(4)将该 Hook 注册到 Runner 中。
需要完成三个步骤,但是实际上我只是想 print 而已,比较繁琐,而 ClosureHook 的作用就是为了简化流程。你现在要做的事情如下所示:
(1)定义如上函数;def getiter(runner): print(runner.iter)
(2)作为参数输入给 ClosureHook,并且实例化 ClosureHook(‘after_train_iter’, getiter);
(3)将该 Hook 注册到 Runner 中。
ClosureHook 主要用于一些非常简单的 Hook,但是又不想重新定义一个类来实现,此时就可以通过定义函数,然后传递给 ClosureHook 即可。
Runner
Runner负责OpenMMLab中所有框架pipeline的过程调度,提供了 以Epoch 和 Iter 为基础的迭代模式以满足不同场景,例如 MMDetection 默认采用 Epoch (配置文件中相关参数都是以 Epoch 为单位),而 MMSegmentation 默认采用 Iter (配置文件中相关参数都是以 Iter 为单位)。配合各类 Hook,以一种优雅的方式实现功能的扩展。
Runner 的使用过程可以分成 4 个步骤:
- Runner 对象初始化;
- 注册各类 Hook 到 Runner 中;
- 调用 Runner 的 resume 或者 load_checkpoint 方法对权重进行加载;
- 运行给定的pipeline工作流。
Runner 初始化
考虑到 Epoch 和 Iter 模式有很多共有逻辑,为了复用,抽象出一个 BaseRunner。BaseRunner 初始化是一个常规初始化过程,其参数如下:
def __init__(self,
model,
batch_processor=None, # 已废弃
optimizer=None,
work_dir=None,
logger=None,
meta=None, # 提供了该参数,则会保存到 ckpt 中
max_iters=None, # 这两个参数非常关键,如果没有给定,则内部自己计算
max_epochs=None):
注册 Hook
register_training_hooks
,注册默认Hook:
def register_training_hooks(self,
lr_config, # lr相关
optimizer_config=None, # 优化器相关
checkpoint_config=None, # ckpt 保存相关
log_config=None, # 日志记录相关
momentum_config=None, # momentum 相关
timer_config=dict(type='IterTimerHook')) # 迭代时间统计
register_hook
,上面以外的其他所有 Hook,都是通过本方式进行注册,例如 eval_hook、custom_hooks 和 DistSamplerSeedHook 等等:
def register_hook(self, hook, priority='NORMAL'):
# 获取优先级
priority = get_priority(priority)
hook.priority = priority
# 基于优先级计算当前 hook 插入位置
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
resume 或者 load_checkpoint
resume 方法用于训练过程中停止然后恢复训练时加载权重,而 load_checkpoint 仅仅是加载预训练权重,这个预训练权重可以来自官方,也可以来自自己训练后的权重,如果有 key 不匹配的参数则会自动跳过。
run
run 方法调用后才是真正开启工作流,并且由于 Epoch 和 Iter 模式流程不一样,所以在各自子类实现。
(1) EpochBasedRunner run
def run(self,
data_loaders, # dataloader 列表
workflow, # 工作流列表,长度需要和 data_loaders 一致
max_epochs=None,
**kwargs):
- 假设只想运行训练工作流,则可以设置 workflow = [(‘train’, 1)],表示 data_loader 中的数据进行迭代训练
- 假设想运行训练和验证工作流,则可以设置 workflow = [(‘train’, 3), (‘val’,1)],表示先训练 3 个 epoch ,然后切换到 val 工作流,运行 1 个 epoch,然后循环,直到训练 epoch 次数达到指定值
- 工作流设置非常自由,例如你可以先验证再训练 workflow = [(‘val’, 1), (‘train’,1)]
需要注意的是:如果工作流有两个,那么 data_loaders 中也需要提供两个 dataloader。其核心逻辑如下:
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
# epoch 模式,需要自动计算出 _max_iters
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
# 调用注册到 runner 中的所有 hook 的 before_run 方法,表示开启 run 前
self.call_hook('before_run')
# 如果没有达到退出条件,就一直运行工作流
while self.epoch < self._max_epochs:
# 遍历工作流
for i, flow in enumerate(workflow):
# 模式,和当前工作流需要运行的 epoch 次数
mode, epochs = flow
epoch_runner = getattr(self, mode)
for _ in range(epochs):
if mode == 'train' and self.epoch >= self._max_epochs:
break
# 开始一个 epoch 的迭代
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
# 调用注册到 runner 中的所有 hook 的 after_run 方法,表示结束 run 后
self.call_hook('after_run')
run 方法中定义的是通用工作流切换流程,真正完成一个 epoch 工作流是调用了工作流函数。目前支持 train 和 val 两个工作流,那么 epoch_runner(data_loaders[i], **kwargs)
调用的实际上是 train 或者 val 方法:
# train 和 val 方法逻辑非常相似
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True)
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
上述逻辑是遍历 data_loader,然后进行 batch 级别的迭代训练或者验证,比较容易理解。真正完成一个 batch 的训练或者验证是调用了 self.run_iter
:
# 简化逻辑
def run_iter(self, data_batch, train_mode, **kwargs):
# 调用 model 自身的 train_step 或者 val_step 方法
if train_mode:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
(2) IterBasedRunner run
IterBasedRunner模式以迭代次数作为循环终止的条件, 没有 epoch 的概念,故 IterBasedRunner 的 run 方法有些许改动:
- 工作流终止条件不再是 epoch,而是 iter
- Hook 的生命周期方法也不涉及 epoch,全部是 iter 相关方法
由于MMDetection采用的EpochBasedRunner,而非IterBasedRunner,其详细的代码逻辑不再展开。
(3)EpochBasedRunner与IterBasedRunner比较文章来源:https://www.toymoban.com/news/detail-520586.html
假设数据长度是 1024,batch=4,那么 dataloader 长度是 1024/4=256, 也就是一个 epoch 是 256 次迭代,在 Iter 训练模式下,计划训练 100000 个迭代,若在Epoch训练模式下,那么实际上运行了 100000//256=39 个 epoch。文章来源地址https://www.toymoban.com/news/detail-520586.html
到了这里,关于MMDetection学习笔记(四):核心组件分析的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!