PaddleSeg Segmentation Framework Walkthrough [01]: Core Design Analysis

tools/train.py

import argparse
import random
import numpy as np
import cv2

import paddle
from paddleseg.cvlibs import Config, SegBuilder
from paddleseg.utils import get_sys_env, logger, utils
from paddleseg.core import train


def parse_args():
    # Create the argument parser
    parser = argparse.ArgumentParser(description='Model training')
    # Register the command-line arguments and options of interest
    # Config file with the training and model settings
    parser.add_argument(
        "--config", 
        help="The config file.",
        type=str,
        default=None)
    # Training device
    parser.add_argument(
        '--device',
        help='Set the device place for training model.',
        type=str,
        default='gpu',
        choices=['cpu', 'gpu', 'xpu', 'npu', 'mlu'])
    # Directory for saving model checkpoints
    parser.add_argument(
        '--save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')
    # Number of worker processes (num_workers) for the data loader
    parser.add_argument(
        '--num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)
    # Whether to evaluate on the validation set during training
    parser.add_argument(
        '--do_eval',
        help='Eval while training',
        action='store_true')
    # Whether to log training data to VisualDL for visualization
    parser.add_argument(
        '--use_vdl',
        help='Whether to record the data to VisualDL during training',
        action='store_true')
    # Whether to keep an exponential moving average (EMA) of the model weights during training
    parser.add_argument(
        '--use_ema',
        help='Whether to ema the model in training.',
        action='store_true')
    # Checkpoint path to resume training from
    parser.add_argument(
        '--resume_model',
        help='The path of resume model',
        type=str,
        default=None)
    # Total number of training iterations
    parser.add_argument(
        '--iters',
        help='iters for training',
        type=int,
        default=None)
    # Batch size (per GPU or CPU)
    parser.add_argument(
        '--batch_size',
        help='Mini batch size of one gpu or cpu',
        type=int,
        default=None)
    # Initial learning rate
    parser.add_argument(
        '--learning_rate',
        help='Learning rate',
        type=float,
        default=None)
    # Interval (in iterations) between saved checkpoints
    parser.add_argument(
        '--save_interval',
        help='How many iters to save a model snapshot once during training.',
        type=int,
        default=1000)
    # Interval (in iterations) between log messages
    parser.add_argument(
        '--log_iters',
        help='Display logging information at every log_iters',
        type=int,
        default=10)
    # Maximum number of checkpoint files to keep
    parser.add_argument(
        '--keep_checkpoint_max',
        help='Maximum number of checkpoints to save',
        type=int,
        default=20)
    # Random seed
    parser.add_argument(
        '--seed',
        help='Set the random seed during training.',
        type=int,
        default=None)
    # Mixed-precision training (fp16) or normal training (fp32)
    parser.add_argument(
        "--precision",
        type=str,
        default="fp32",
        choices=["fp32", "fp16"],
        help="Use AMP (Auto mixed precision) if precision='fp16'. If precision='fp32', the training is normal."
    )
    # Automatic mixed precision (AMP) level
    parser.add_argument(
        "--amp_level",
        default="O1",
        type=str,
        choices=["O1", "O2"],
        help="Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision, \
              the input data type of each operator will be casted by white_list and black_list; \
              O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, \
              except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp).")
    # Profiler options
    parser.add_argument(
        '--profiler_options',
        type=str,
        default=None,
        help='The option of train profiler. If profiler_options is not None, the train profiler is enabled. ' \
             'Refer to the paddleseg/utils/train_profiler.py for details.'
    )
    # Layout of the training data: "NCHW" or "NHWC"
    parser.add_argument(
        '--data_format',
        help='Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW".',
        type=str,
        default='NCHW')
    # Number of times the dataset samples are repeated in each epoch
    parser.add_argument(
        '--repeats',
        type=int,
        default=1,
        help="Repeat the samples in the dataset for `repeats` times in each epoch."
    )
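    # nargs sets how many values an option consumes: '+' means one or more, collected into a list.
    # (By contrast, with nargs='?' an option given without a value takes `const`, and `default` when absent.)
    # Override arbitrary key-value pairs in the config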
    parser.add_argument(
        '--opts', 
        help='Update the key-value pairs of all options.', 
        nargs='+')
    
    # Parse the command line and return the arguments
    return parser.parse_args()
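To see what `action='store_true'` and `nargs='+'` actually produce, here is a minimal sketch that rebuilds just three of the options above with the standard `argparse` module and feeds it an explicit argument list. The config path is a made-up placeholder.

```python
import argparse


def build_parser():
    # Only three of the train.py options, reproduced for illustration
    parser = argparse.ArgumentParser(description='Model training')
    parser.add_argument('--config', type=str, default=None)
    # store_true: the value is False unless the flag appears
    parser.add_argument('--do_eval', action='store_true')
    # nargs='+': one or more values, collected into a list of strings
    parser.add_argument('--opts', nargs='+')
    return parser


args = build_parser().parse_args([
    '--config', 'configs/demo.yml',  # hypothetical path; argparse does not check it
    '--do_eval',
    '--opts', 'batch_size=1', 'test_config.scales=0.75,1.0,1.25',
])
print(args.do_eval)  # True
print(args.opts)     # ['batch_size=1', 'test_config.scales=0.75,1.0,1.25']

defaults = build_parser().parse_args([])
print(defaults.do_eval, defaults.opts)  # False None
```

The `--opts` values are still raw strings at this point; they only become typed values later, in `update_config_dict`.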


def main(args):
    # A config file must be specified
    assert args.config is not None, 'No configuration file specified, please set --config'
    # Build the Config object; values passed on the command line override those in the YAML file
    cfg = Config(
        args.config,
        learning_rate=args.learning_rate,
        iters=args.iters,
        batch_size=args.batch_size,
        opts=args.opts)
    builder = SegBuilder(cfg)
    
    utils.show_env_info()
    utils.show_cfg_info(cfg)
    utils.set_seed(args.seed)
    utils.set_device(args.device)
    utils.set_cv2_num_threads(args.num_workers)
    
    # The NHWC data format is only supported by the DeepLabV3+ model
    if args.data_format == 'NHWC':
        if cfg.dic['model']['type'] != 'DeepLabV3P':
            raise ValueError('The "NHWC" data format only supports the DeepLabV3P model!')
        # Propagate data_format to every part that depends on it (model, backbone, losses)
        cfg.dic['model']['data_format'] = args.data_format
        cfg.dic['model']['backbone']['data_format'] = args.data_format
        loss_len = len(cfg.dic['loss']['types'])
        for i in range(loss_len):
            cfg.dic['loss']['types'][i]['data_format'] = args.data_format
    
    model = utils.convert_sync_batchnorm(builder.model, args.device)
    
    # Training dataset
    train_dataset = builder.train_dataset
    # Repeat the dataset samples `repeats` times per epoch
    if args.repeats > 1:
        train_dataset.file_list *= args.repeats
    # Validation dataset (only built when --do_eval is set)
    val_dataset = builder.val_dataset if args.do_eval else None
    # Optimizer
    optimizer = builder.optimizer
    # Loss functions
    loss = builder.loss
    
    train(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        optimizer=optimizer,
        save_dir=args.save_dir,
        iters=cfg.iters,
        batch_size=cfg.batch_size,
        resume_model=args.resume_model,
        save_interval=args.save_interval,
        log_iters=args.log_iters,
        num_workers=args.num_workers,
        use_vdl=args.use_vdl,
        use_ema=args.use_ema,
        losses=loss,
        keep_checkpoint_max=args.keep_checkpoint_max,
        test_config=cfg.test_config,
        precision=args.precision,
        amp_level=args.amp_level,
        profiler_options=args.profiler_options,
        to_static_training=cfg.to_static_training)


if __name__ == '__main__':
    args = parse_args()
    main(args)

paddleseg/cvlibs/config.py

import six
import codecs
import os
from ast import literal_eval
from typing import Any, Dict, Optional
import yaml

import paddle
from paddleseg.cvlibs import config_checker as checker
from paddleseg.cvlibs import manager
from paddleseg.utils import logger, utils

_INHERIT_KEY = '_inherited_'
_BASE_KEY = '_base_'


class Config(object):
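    """
    Parses the parameter configuration file; only yaml/yml files are supported.

    Hyper-parameters in the config file:
        batch_size: Number of samples per GPU.
        iters: Total number of training iterations.
        train_dataset: Training data config, including type/data_root/transforms/mode.
            For data type, please refer to paddleseg.datasets.
            For specific transforms, please refer to paddleseg.transforms.transforms.
        val_dataset: Validation data config, including type/data_root/transforms/mode.
        optimizer: Optimizer config; please refer to paddleseg.optimizers.
        lr_scheduler: Learning-rate scheduler config. `type` names a scheduler class in
                      paddle.optimizer.lr (e.g. PolynomialDecay) and `learning_rate` is the
                      initial learning rate. Decay parameters such as power and end_lr
                      need to be tuned experimentally.
        loss: Loss config; multiple losses (multi-loss) are supported.
            The order of the loss types must match the outputs of the segmentation model, and coef
            holds the weight of each corresponding loss; the number of coef values must equal the
            number of model outputs. If the same loss type is used for every output, a single loss
            type may be listed; otherwise the number of loss types must equal the number of coef values.
        model: Model config, including type/backbone and model-dependent arguments.
            For model types, please refer to paddleseg.models.
            For backbone types, please refer to paddleseg.models.backbones.

    Args:
        path (str): Path of the config file; only yaml format is supported.
        learning_rate (float, optional): If set, overrides lr_scheduler.learning_rate. Default: None
        batch_size (int, optional): If set, overrides batch_size. Default: None
        iters (int, optional): If set, overrides iters. Default: None
        opts (list, optional): key=value strings used to update any option. Default: None
        checker (ConfigChecker, optional): Rules applied to the parsed config; a default
            checker is built when None. Default: None

    Examples:

        from paddleseg.cvlibs import Config, SegBuilder

        # Create a cfg object with yaml file path.
        cfg = Config(yaml_cfg_path)
        builder = SegBuilder(cfg)

        # Each component is built lazily, the first time its property is accessed.
        train_dataset = builder.train_dataset

        # Building the model may read attributes such as NUM_CLASSES from the
        # train dataset class, which SegBuilder resolves by itself.
        model = builder.model
        ...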
    """
    def __init__(
            self,
            path: str,
            learning_rate: Optional[float]=None,
            batch_size: Optional[int]=None,
            iters: Optional[int]=None,
            opts: Optional[list]=None,
            checker: Optional[checker.ConfigChecker]=None, ):
        assert os.path.exists(path), \
            'Config path ({}) does not exist'.format(path)
        assert path.endswith('yml') or path.endswith('yaml'), \
            'Config file ({}) should be yaml format'.format(path)

        # 将yaml文件解析成字典dict
        self.dic = self._parse_from_yaml(path)
        # 根据传参来进行配置文件yaml字典dict的更新
        self.dic = self.update_config_dict(
            self.dic,
            learning_rate=learning_rate,
            batch_size=batch_size,
            iters=iters,
            opts=opts)

        if checker is None:
            checker = self._build_default_checker()
        checker.apply_all_rules(self)

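    # The @property decorator turns a method into a read-only attribute of the same name, which
    # prevents the attribute from being reassigned. The dict-valued properties additionally return
    # a shallow copy, so callers can change top-level keys without touching self.dic.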
    @property
    def batch_size(self) -> int:
        return self.dic.get('batch_size')

    @property
    def iters(self) -> int:
        return self.dic.get('iters')

    @property
    def to_static_training(self) -> bool:
        return self.dic.get('to_static_training', False)

    @property
    def model_cfg(self) -> Dict:
        return self.dic.get('model', {}).copy()

    @property
    def loss_cfg(self) -> Dict:
        return self.dic.get('loss', {}).copy()

    @property
    def distill_loss_cfg(self) -> Dict:
        return self.dic.get('distill_loss', {}).copy()

    @property
    def lr_scheduler_cfg(self) -> Dict:
        return self.dic.get('lr_scheduler', {}).copy()

    @property
    def optimizer_cfg(self) -> Dict:
        return self.dic.get('optimizer', {}).copy()

    @property
    def train_dataset_cfg(self) -> Dict:
        return self.dic.get('train_dataset', {}).copy()

    @property
    def val_dataset_cfg(self) -> Dict:
        return self.dic.get('val_dataset', {}).copy()

    # TODO merge test_config into val_dataset
    @property
    def test_config(self) -> Dict:
        return self.dic.get('test_config', {}).copy()

    @classmethod
    def update_config_dict(cls, dic: dict, *args, **kwargs) -> dict:
        return update_config_dict(dic, *args, **kwargs)

    # In Python, @classmethod declares a class method, which can be called as ClassName.method()
    # as well as through an instance of the class.
    @classmethod
    def _parse_from_yaml(cls, path: str, *args, **kwargs) -> dict:
        return parse_from_yaml(path, *args, **kwargs)

    @classmethod
    def _build_default_checker(cls):
        rules = []
        rules.append(checker.DefaultPrimaryRule())
        rules.append(checker.DefaultSyncNumClassesRule())
        rules.append(checker.DefaultSyncImgChannelsRule())
        # Losses
        rules.append(checker.DefaultLossRule('loss'))
        rules.append(checker.DefaultSyncIgnoreIndexRule('loss'))
        # Distillation losses
        rules.append(checker.DefaultLossRule('distill_loss'))
        rules.append(checker.DefaultSyncIgnoreIndexRule('distill_loss'))

        return checker.ConfigChecker(rules, allow_update=True)

    def __str__(self) -> str:
        # Use NoAliasDumper so the dumped YAML contains no anchors/aliases
        return yaml.dump(self.dic, Dumper=utils.NoAliasDumper)
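A stripped-down sketch of the read-only property pattern used by `Config`. `MiniConfig` is a hypothetical stand-in, not PaddleSeg's class. It shows that assigning to a property without a setter raises `AttributeError`, and that `.copy()` is shallow: new top-level keys stay in the copy, while edits inside nested dicts still reach `self.dic`.

```python
class MiniConfig:
    """Hypothetical stand-in for Config, keeping only two properties."""

    def __init__(self, dic):
        self.dic = dic

    @property
    def batch_size(self):
        return self.dic.get('batch_size')

    @property
    def model_cfg(self):
        # Shallow copy, as in Config.model_cfg
        return self.dic.get('model', {}).copy()


cfg = MiniConfig({'batch_size': 4,
                  'model': {'type': 'DeepLabV3P', 'backbone': {'type': 'ResNet50_vd'}}})

m = cfg.model_cfg
m['num_classes'] = 19                 # new top-level key: stays in the copy
print('num_classes' in cfg.dic['model'])             # False

m['backbone']['pretrained'] = None    # nested dict is shared with cfg.dic
print('pretrained' in cfg.dic['model']['backbone'])  # True

try:
    cfg.batch_size = 8                # property without a setter
    read_only = False
except AttributeError:
    read_only = True
print(read_only)                      # True
```

This is also why `train.py` writes `data_format` straight into `cfg.dic` rather than into a property: the properties are read-only views that re-read `self.dic` on every access.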


def parse_from_yaml(path: str):
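    """
    Recursively parse a yaml file and build the config dict, following `_base_` inheritance.
    """
    # Read the yaml file into a dict
    with codecs.open(path, 'r', 'utf-8') as file:
        dic = yaml.load(file, Loader=yaml.FullLoader)

    if _BASE_KEY in dic:
        # dict.pop() removes the given key and returns its value:
        # here, the base file(s) this config inherits from
        base_files = dic.pop(_BASE_KEY)
        if isinstance(base_files, str):
            base_files = [base_files]
        # For each inherited base file
        for bf in base_files:
            # os.path.dirname(path) strips the file name, so base paths are relative to this file's directory
            base_path = os.path.join(os.path.dirname(path), bf)
            # Recursively parse the inherited _base_ yaml file
            base_dic = parse_from_yaml(base_path)
            # Merge: values in the current file override those from the base file
            dic = merge_config_dicts(dic, base_dic)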
    return dic
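The `_base_` mechanism can be tried without YAML files. The sketch below assumes a hypothetical in-memory `FILES` mapping in place of the file system and a simplified merge that leaves out the `_inherited_` switch; the file names are illustrative only. One difference from the code above: because the paths are dictionary keys rather than real files, `posixpath.normpath` is needed to collapse the `..` segment.

```python
import posixpath

# Hypothetical in-memory "files" standing in for YAML files on disk
FILES = {
    'configs/_base_/cityscapes.yml': {
        'batch_size': 2, 'iters': 80000,
        'train_dataset': {'type': 'Cityscapes', 'mode': 'train'},
    },
    'configs/deeplabv3p/deeplabv3p_city.yml': {
        '_base_': '../_base_/cityscapes.yml',
        'batch_size': 4,
        'model': {'type': 'DeepLabV3P'},
    },
}


def merge(dic, base_dic):
    """Simplified merge: values from dic win; nested dicts merge recursively."""
    out = dict(base_dic)
    for key, val in dic.items():
        if isinstance(val, dict) and key in out:
            out[key] = merge(val, out[key])
        else:
            out[key] = val
    return out


def resolve(path):
    dic = dict(FILES[path])  # copy so popping '_base_' leaves FILES untouched
    base_files = dic.pop('_base_', [])
    if isinstance(base_files, str):
        base_files = [base_files]
    for bf in base_files:
        # Base paths are relative to the directory of the current file
        base_path = posixpath.normpath(posixpath.join(posixpath.dirname(path), bf))
        dic = merge(dic, resolve(base_path))
    return dic


cfg = resolve('configs/deeplabv3p/deeplabv3p_city.yml')
print(cfg['batch_size'], cfg['iters'], cfg['model']['type'])  # 4 80000 DeepLabV3P
```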


def merge_config_dicts(dic, base_dic):
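    """
    Merge dic into base_dic; values in dic take precedence.
    """
    base_dic = base_dic.copy() # shallow copy
    dic = dic.copy() # shallow copy

    # If dic sets `_inherited_: False`, skip inheritance and return dic as is
    if not dic.get(_INHERIT_KEY, True):
        dic.pop(_INHERIT_KEY)
        return dic

    # Iterate over the key-value pairs of dic
    for key, val in dic.items():
        # If val is a dict and key also exists in base_dic, merge recursively
        if isinstance(val, dict) and key in base_dic: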
            base_dic[key] = merge_config_dicts(val, base_dic[key])
        else:
            base_dic[key] = val

    return base_dic
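Below, `merge_config_dicts` is reproduced with the same logic (minus comments) so that it runs without PaddleSeg, followed by three hypothetical child configs. The cases show recursive merging of nested dicts, the `_inherited_: False` switch that replaces a sub-dict instead of merging it, and the fact that lists are replaced wholesale rather than merged element by element.

```python
_INHERIT_KEY = '_inherited_'


def merge_config_dicts(dic, base_dic):
    # Same logic as the function above, reproduced so the example runs standalone
    base_dic = base_dic.copy()
    dic = dic.copy()
    if not dic.get(_INHERIT_KEY, True):
        dic.pop(_INHERIT_KEY)
        return dic
    for key, val in dic.items():
        if isinstance(val, dict) and key in base_dic:
            base_dic[key] = merge_config_dicts(val, base_dic[key])
        else:
            base_dic[key] = val
    return base_dic


base = {'loss': {'types': [{'type': 'CrossEntropyLoss'}], 'coef': [1]},
        'model': {'type': 'DeepLabV3P', 'backbone': {'type': 'ResNet50_vd'}, 'num_classes': 19}}

# 1) Default: nested dicts are merged key by key
print(merge_config_dicts({'model': {'num_classes': 2}}, base)['model'])
# {'type': 'DeepLabV3P', 'backbone': {'type': 'ResNet50_vd'}, 'num_classes': 2}

# 2) `_inherited_: False` replaces the whole sub-dict instead of merging it
child = {'model': {'_inherited_': False, 'type': 'UNet', 'num_classes': 2}}
print(merge_config_dicts(child, base)['model'])
# {'type': 'UNet', 'num_classes': 2}

# 3) Lists are not merged element-wise: the child's list replaces the base list
print(merge_config_dicts({'loss': {'coef': [1, 0.4]}}, base)['loss']['coef'])
# [1, 0.4]
```

Because every level works on a copy, the base dict itself is never modified.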


def update_config_dict(dic: dict,
                       learning_rate: Optional[float]=None,
                       batch_size: Optional[int]=None,
                       iters: Optional[int]=None,
                       opts: Optional[list]=None):
    """Update config"""
    # TODO: If the items to update are marked as anchors in the yaml file,
    # we should synchronize the references.
    dic = dic.copy()

    if learning_rate:
        dic['lr_scheduler']['learning_rate'] = learning_rate
    if batch_size:
        dic['batch_size'] = batch_size
    if iters:
        dic['iters'] = iters

    if opts is not None:
        for item in opts:
            assert ('=' in item) and (len(item.split('=')) == 2), "--opts params should be key=value," \
                " such as `--opts batch_size=1 test_config.scales=0.75,1.0,1.25`, " \
                "but got ({})".format(opts)

            key, value = item.split('=')
            if isinstance(value, six.string_types):
                try:
                    value = literal_eval(value)
                except ValueError:
                    pass
                except SyntaxError:
                    pass
            key_list = key.split('.')

            tmp_dic = dic
            for subkey in key_list[:-1]:
                assert subkey in tmp_dic, "Can not update {}, because it is not in config.".format(key)
                tmp_dic = tmp_dic[subkey]
            tmp_dic[key_list[-1]] = value

    return dic
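The `--opts` branch is where command-line strings become typed config values. A simplified sketch on plain dicts, with the `key=value` format assertion dropped and a hypothetical `apply_opts` name: `ast.literal_eval` turns `'1'` into an int and `'0.75,1.0,1.25'` into a tuple, anything that is not a Python literal stays a string, and dotted keys walk into nested dicts. As in the original, every intermediate key must already exist, but the final key may be new.

```python
from ast import literal_eval


def apply_opts(dic, opts):
    """Hypothetical, simplified version of the --opts handling in update_config_dict."""
    for item in opts:
        key, value = item.split('=')
        try:
            # '1' -> 1, '0.75,1.0,1.25' -> (0.75, 1.0, 1.25), 'True' -> True
            value = literal_eval(value)
        except (ValueError, SyntaxError):
            pass  # not a Python literal: keep the raw string, e.g. 'NHWC'
        *parents, last = key.split('.')
        node = dic
        for subkey in parents:
            assert subkey in node, f"Can not update {key}, because it is not in config."
            node = node[subkey]
        node[last] = value
    return dic


cfg = {'batch_size': 2, 'test_config': {'aug_eval': False}}
apply_opts(cfg, ['batch_size=1', 'test_config.scales=0.75,1.0,1.25', 'test_config.aug_eval=True'])
print(cfg)
# {'batch_size': 1, 'test_config': {'aug_eval': True, 'scales': (0.75, 1.0, 1.25)}}

try:
    apply_opts(cfg, ['model.num_classes=2'])  # 'model' is not in cfg
except AssertionError as e:
    print(e)  # Can not update model.num_classes, because it is not in config.
```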

paddleseg/cvlibs/builder.py

import copy
from typing import Any, Optional
import yaml
import paddle

from paddleseg.cvlibs import manager, Config
from paddleseg.utils import utils, logger
from paddleseg.utils.utils import CachedProperty as cached_property


class Builder(object):
    """
    Base class for building components.

    Args:
        config (Config): A Config object.
        comp_list (list, optional): List of component managers used to look up component classes. Default: None
    """
    def __init__(self, config: Config, comp_list: Optional[list]=None):
        super().__init__()
        self.config = config
        self.comp_list = comp_list
    
    # Example of a nested component config:
    # {'type': 'MixedLoss', 'losses': [{'type': 'CrossEntropyLoss'}, {'type': 'LovaszSoftmaxLoss'}], 'coef': [0.4, 0.6]}

    def build_component(self, cfg):
        """
        Create Python object, such as model, loss, dataset, etc.
        """
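        # copy.copy() is a shallow copy: it copies the outer object but not the objects nested inside it.
        # copy.deepcopy() is a deep copy: it copies the object and all nested objects, so later changes
        # to the copy never affect the original config.
        cfg = copy.deepcopy(cfg)
        if 'type' not in cfg:
            raise RuntimeError(
                "It is not possible to create a component object from {}, as 'type' is not specified.".format(cfg)
                )
        # Component type name
        class_type = cfg.pop('type')
        # Look up the component class
        com_class = self.load_component_class(class_type)
        # Constructor arguments; nested configs that have a 'type' key are built recursively
        params = {}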
        for key, val in cfg.items():
            if self.is_meta_type(val):
                params[key] = self.build_component(val)
            elif isinstance(val, list):
                params[key] = [
                    self.build_component(item)
                    if self.is_meta_type(item) else item for item in val
                ]
            else:
                params[key] = val
                
        # Instantiate the component class
        try:
            obj = self.build_component_impl(com_class, **params)
        except Exception as e:
            if hasattr(com_class, '__name__'):
                com_name = com_class.__name__
            else:
                com_name = ''
            raise RuntimeError(
                f"Tried to create a {com_name} object, but the operation has failed. "
                "Please double check the arguments used to create the object.\n"
                f"The error message is: \n{str(e)}")

        return obj

    def build_component_impl(self, component_class, *args, **kwargs):
        return component_class(*args, **kwargs)

    def load_component_class(self, class_type):
        for com in self.comp_list:
            if class_type in com.components_dict:
                return com[class_type]
        raise RuntimeError("The specified component ({}) was not found.".format(class_type))

    @classmethod
    def is_meta_type(cls, obj):
        # TODO: should we define a protocol (see https://peps.python.org/pep-0544/#defining-a-protocol)
        # to make it more pythonic?
        return isinstance(obj, dict) and 'type' in obj

    @classmethod
    def show_msg(cls, name, cfg):
        msg = 'Use the following config to build {}\n'.format(name)
        msg += str(yaml.dump({name: cfg}, Dumper=utils.NoAliasDumper))
        logger.info(msg[0:-1])
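`build_component` is the heart of the builder: a config dict with a `type` key is turned into an object, recursing into any nested dicts (or lists of dicts) that also carry a `type`. The sketch below reproduces that recursion with a hypothetical plain-dict `REGISTRY` and three dummy classes that merely borrow PaddleSeg's loss names; they are not the real implementations.

```python
import copy


# Dummy classes standing in for registered components
class CrossEntropyLoss:
    def __init__(self, ignore_index=255):
        self.ignore_index = ignore_index


class LovaszSoftmaxLoss:
    def __init__(self, ignore_index=255):
        self.ignore_index = ignore_index


class MixedLoss:
    def __init__(self, losses, coef):
        self.losses, self.coef = losses, coef


REGISTRY = {c.__name__: c for c in (CrossEntropyLoss, LovaszSoftmaxLoss, MixedLoss)}


def is_meta_type(obj):
    return isinstance(obj, dict) and 'type' in obj


def build_component(cfg):
    cfg = copy.deepcopy(cfg)           # never mutate the caller's config
    com_class = REGISTRY[cfg.pop('type')]
    params = {}
    for key, val in cfg.items():
        if is_meta_type(val):          # nested component: build it first
            params[key] = build_component(val)
        elif isinstance(val, list):    # a list may mix components and plain values
            params[key] = [build_component(v) if is_meta_type(v) else v for v in val]
        else:
            params[key] = val
    return com_class(**params)


cfg = {'type': 'MixedLoss',
       'losses': [{'type': 'CrossEntropyLoss'}, {'type': 'LovaszSoftmaxLoss', 'ignore_index': 0}],
       'coef': [0.4, 0.6]}
loss = build_component(cfg)
print(type(loss).__name__, [type(x).__name__ for x in loss.losses], loss.coef)
# MixedLoss ['CrossEntropyLoss', 'LovaszSoftmaxLoss'] [0.4, 0.6]
print(cfg['type'])  # MixedLoss: still present thanks to deepcopy
```

The deep copy matters because `cfg.pop('type')` would otherwise strip the `type` key from the config that `Config` keeps, breaking any second build from the same config.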


class SegBuilder(Builder):
    """
    This class builds the components used for semantic segmentation.
    """
    def __init__(self, config, comp_list=None):
        # Default list of component managers
        if comp_list is None:
            comp_list = [
                manager.MODELS, manager.BACKBONES, manager.DATASETS,
                manager.TRANSFORMS, manager.LOSSES, manager.OPTIMIZERS
            ]
        super().__init__(config, comp_list)

    # @cached_property turns the decorated method into an attribute of the object.
    # The method runs the first time the attribute is accessed; later accesses return the cached result.
    @cached_property
    def model(self) -> paddle.nn.Layer:
        model_cfg = self.config.model_cfg
        assert model_cfg != {}, 'No model specified in the configuration file.'

        if self.config.train_dataset_cfg['type'] != 'Dataset':
            # Check and sync num_classes between the model config and the dataset class
            assert hasattr(self.train_dataset_class, 'NUM_CLASSES'), \
                'If train_dataset class is not `Dataset`, it must have `NUM_CLASSES` attr.'
            num_classes = getattr(self.train_dataset_class, 'NUM_CLASSES')
            if 'num_classes' in model_cfg:
                assert model_cfg['num_classes'] == num_classes, \
                    'The num_classes is not consistent for model config ({}) ' \
                    'and train_dataset class ({}) '.format(model_cfg['num_classes'], num_classes)
            else:
                logger.warning(
                    'Add the `num_classes` in train_dataset class to model config. '
                    'We suggest you manually set `num_classes` in model config.'
                )
                model_cfg['num_classes'] = num_classes
            
            # Check and sync in_channels between the model config and the dataset class
            assert hasattr(self.train_dataset_class, 'IMG_CHANNELS'), \
                'If train_dataset class is not `Dataset`, it must have `IMG_CHANNELS` attr.'
            in_channels = getattr(self.train_dataset_class, 'IMG_CHANNELS')
            x = utils.get_in_channels(model_cfg)
            if x is not None:
                assert x == in_channels, \
                    'The in_channels in model config ({}) and the img_channels in train_dataset ' \
                    'class ({}) is not consistent'.format(x, in_channels)
            else:
                model_cfg = utils.set_in_channels(model_cfg, in_channels)
                logger.warning(
                    'Add the `in_channels` in train_dataset class to model config. '
                    'We suggest you manually set `in_channels` in model config.'
                )
        # Log the config used to build the model
        self.show_msg('model', model_cfg)
        return self.build_component(model_cfg)

    @cached_property
    def optimizer(self) -> paddle.optimizer.Optimizer:
        opt_cfg = self.config.optimizer_cfg
        assert opt_cfg != {}, 'No optimizer specified in the configuration file.'
        # For compatibility
        if opt_cfg['type'] == 'adam':
            opt_cfg['type'] = 'Adam'
        if opt_cfg['type'] == 'sgd':
            opt_cfg['type'] = 'SGD'
        if opt_cfg['type'] == 'SGD' and 'momentum' in opt_cfg:
            opt_cfg['type'] = 'Momentum'
            logger.info('If the type is SGD and momentum in optimizer config, '
                        'the type is changed to Momentum.')
        self.show_msg('optimizer', opt_cfg)
        opt = self.build_component(opt_cfg)
        opt = opt(self.model, self.lr_scheduler)
        return opt

    @cached_property
    def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler:
        lr_cfg = self.config.lr_scheduler_cfg
        assert lr_cfg != {}, 'No lr_scheduler specified in the configuration file.'

        use_warmup = False
        if 'warmup_iters' in lr_cfg:
            use_warmup = True
            warmup_iters = lr_cfg.pop('warmup_iters')
            assert 'warmup_start_lr' in lr_cfg, \
                "When use warmup, please set warmup_start_lr and warmup_iters in lr_scheduler"
            warmup_start_lr = lr_cfg.pop('warmup_start_lr')
            end_lr = lr_cfg['learning_rate']

        lr_type = lr_cfg.pop('type')
        if lr_type == 'PolynomialDecay':
            iters = self.config.iters - warmup_iters if use_warmup else self.config.iters
            iters = max(iters, 1)
            lr_cfg.setdefault('decay_steps', iters)

        try:
            lr_sche = getattr(paddle.optimizer.lr, lr_type)(**lr_cfg)
        except Exception as e:
            raise RuntimeError(
                "Create {} has failed. Please check lr_scheduler in config. "
                "The error message: {}".format(lr_type, e))

        if use_warmup:
            lr_sche = paddle.optimizer.lr.LinearWarmup(
                learning_rate=lr_sche,
                warmup_steps=warmup_iters,
                start_lr=warmup_start_lr,
                end_lr=end_lr)

        return lr_sche

    @cached_property
    def loss(self) -> dict:
        loss_cfg = self.config.loss_cfg
        assert loss_cfg != {}, 'No loss specified in the configuration file.'
        return self._build_loss('loss', loss_cfg)

    @cached_property
    def distill_loss(self) -> dict:
        loss_cfg = self.config.distill_loss_cfg
        assert loss_cfg != {}, 'No distill_loss specified in the configuration file.'
        return self._build_loss('distill_loss', loss_cfg)

    def _build_loss(self, loss_name, loss_cfg: dict):
        def _check_helper(loss_cfg, ignore_index):
            if 'ignore_index' not in loss_cfg:
                loss_cfg['ignore_index'] = ignore_index
                logger.warning('Add the `ignore_index` in train_dataset class to {} config. ' \
                    'We suggest you manually set `ignore_index` in {} config.'.format(loss_name, loss_name)
                )
            else:
                assert loss_cfg['ignore_index'] == ignore_index, \
                    'the ignore_index in loss and train_dataset must be the same. Currently, loss ignore_index = {}, '\
                    'train_dataset ignore_index = {}'.format(loss_cfg['ignore_index'], ignore_index)

        # Check and sync ignore_index between the loss config and the dataset class
        if self.config.train_dataset_cfg['type'] != 'Dataset':
            assert hasattr(self.train_dataset_class, 'IGNORE_INDEX'), \
                'If train_dataset class is not `Dataset`, it must have `IGNORE_INDEX` attr.'
            ignore_index = getattr(self.train_dataset_class, 'IGNORE_INDEX')
            for loss_cfg_i in loss_cfg['types']:
                if loss_cfg_i['type'] == 'MixedLoss':
                    # [{'type': 'CrossEntropyLoss'}, {'type': 'LovaszSoftmaxLoss'}]
                    for loss_cfg_j in loss_cfg_i['losses']:
                        _check_helper(loss_cfg_j, ignore_index)
                else:
                    _check_helper(loss_cfg_i, ignore_index)
        # Log the config used to build the loss
        self.show_msg(loss_name, loss_cfg)
        loss_dict = {'coef': loss_cfg['coef'], "types": []}
        # {'type': 'MixedLoss', 'losses': [{'type': 'CrossEntropyLoss'}, {'type': 'LovaszSoftmaxLoss'}], 'coef': [0.4, 0.6]}
        for item in loss_cfg['types']:
            loss_dict['types'].append(self.build_component(item))
        
        return loss_dict

    @cached_property
    def train_dataset(self) -> paddle.io.Dataset:
        dataset_cfg = self.config.train_dataset_cfg
        assert dataset_cfg != {}, 'No train_dataset specified in the configuration file.'
        self.show_msg('train_dataset', dataset_cfg)
        dataset = self.build_component(dataset_cfg)
        assert len(dataset) != 0, \
            'The number of samples in train_dataset is 0. Please check whether the dataset is valid.'
        return dataset

    @cached_property
    def val_dataset(self) -> paddle.io.Dataset:
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, 'No val_dataset specified in the configuration file.'
        self.show_msg('val_dataset', dataset_cfg)
        dataset = self.build_component(dataset_cfg)
        if len(dataset) == 0:
            logger.warning('The number of samples in val_dataset is 0. Please ensure this is the desired behavior.')
        return dataset

    @cached_property
    def train_dataset_class(self) -> Any:
        dataset_cfg = self.config.train_dataset_cfg
        assert dataset_cfg != {}, 'No train_dataset specified in the configuration file.'
        dataset_type = dataset_cfg.get('type')
        return self.load_component_class(dataset_type)

    @cached_property
    def val_dataset_class(self) -> Any:
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, 'No val_dataset specified in the configuration file.'
        dataset_type = dataset_cfg.get('type')
        return self.load_component_class(dataset_type)

    @cached_property
    def val_transforms(self) -> list:
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, 'No val_dataset specified in the configuration file.'
        transforms = []
        for item in dataset_cfg.get('transforms', []):
            transforms.append(self.build_component(item))
        return transforms
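The `lr_scheduler` property subtracts `warmup_iters` from `decay_steps` for `PolynomialDecay`. A pure-Python sketch of why, based on my reading of the Paddle documentation for `PolynomialDecay` (with `cycle=False`) and `LinearWarmup`: during warmup the rate rises linearly from `warmup_start_lr` to the configured `learning_rate`, and afterwards the wrapped scheduler is stepped from its own step 0, so it only has `iters - warmup_iters` steps in which to decay. The function names and sample numbers are hypothetical.

```python
def poly_decay(step, base_lr, decay_steps, end_lr, power):
    # PolynomialDecay with cycle=False: the step is clamped at decay_steps
    step = min(step, decay_steps)
    return (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


def lr_at(step, base_lr, iters, end_lr, power, warmup_iters=0, warmup_start_lr=0.0):
    if warmup_iters and step < warmup_iters:
        # Linear ramp from warmup_start_lr up to base_lr
        # (SegBuilder passes the configured learning_rate as LinearWarmup's end_lr)
        return warmup_start_lr + (base_lr - warmup_start_lr) * step / warmup_iters
    # After warmup the wrapped scheduler starts again from its own step 0,
    # so only iters - warmup_iters steps remain for the decay
    decay_steps = max(iters - warmup_iters, 1)
    return poly_decay(step - warmup_iters, base_lr, decay_steps, end_lr, power)


cfg = dict(base_lr=0.01, iters=1000, end_lr=0.0, power=0.9,
           warmup_iters=100, warmup_start_lr=1e-5)
print(lr_at(0, **cfg))     # 1e-05: start of warmup
print(lr_at(100, **cfg))   # 0.01: peak, decay begins
print(lr_at(1000, **cfg))  # 0.0: fully decayed exactly at the last iteration
```

If `decay_steps` were left at `iters` (1000 here), the decay would only reach step 900 of 1000 when training stops, leaving the learning rate at about 0.0013 instead of `end_lr`.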

paddleseg/cvlibs/manager.py

Note in particular: the concrete implementation classes here, such as class Cityscapes(Dataset), are called components; a component manager is the registry that holds them, such as the model manager or the datasets manager.

import inspect
from collections.abc import Sequence

import warnings


class ComponentManager:
    """
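    Component manager class.
    Implements a manager through which new components are added correctly; components can be added as classes or functions.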

    Args:
        name (str): The name of component.

    Returns:
        A callable object of ComponentManager.

    Examples 1:

        from paddleseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        class AlexNet: ...
        class ResNet: ...

        model_manager.add_component(AlexNet)
        model_manager.add_component(ResNet)

        # Alternatively, pass a sequence:
        model_manager.add_component([AlexNet, ResNet])
        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}

    Examples 2:

        # An easier way: use it as a Python decorator, placed just above the class declaration.
        from paddleseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        @model_manager.add_component
        class AlexNet: ...

        @model_manager.add_component
        class ResNet: ...

        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
    """
    def __init__(self, name=None):
        self._components_dict = dict()
        self._name = name

    def __len__(self):
        return len(self._components_dict)

    def __repr__(self):
        name_str = self._name if self._name else self.__class__.__name__
        return "{}:{}".format(name_str, list(self._components_dict.keys()))

    def __getitem__(self, item):
        if item not in self._components_dict.keys():
            raise KeyError("{} does not exist in available {}".format(item, self))
        return self._components_dict[item]

    @property
    def components_dict(self):
        return self._components_dict

    @property
    def name(self):
        return self._name

    def _add_single_component(self, component):
        """
        Add a single component to the corresponding manager (e.g. the model manager).
        Args:
            component (function|class): A new component.
        Raises:
            TypeError: When `component` is neither class nor function.
            KeyError: When `component` was added already.
        """
        # Only class and function types are currently supported
        if not (inspect.isclass(component) or inspect.isfunction(component)):
            raise TypeError("Expect class/function type, but received {}".format(type(component)))

        # Get the component's __name__
        component_name = component.__name__

        # Check whether this component has already been added
        # (the component's __name__ is used as the key)
        if component_name in self._components_dict.keys():
            warnings.warn("{} exists already! It is now updated to {} !!!".format(component_name, component))
            self._components_dict[component_name] = component
        else:
            self._components_dict[component_name] = component

    def add_component(self, components):
        """
        Add components to the corresponding manager.
        Args:
            components (function|class|list|tuple): Support four types of components.

        Returns:
            components (function|class|list|tuple): Same with input components.
        """
        # Check whether components is a sequence (e.g. a list or tuple)
        if isinstance(components, Sequence):
            for component in components:
                self._add_single_component(component)
        else:
            component = components
            self._add_single_component(component)

        return components


# Model manager
MODELS = ComponentManager("models")
# Backbone manager
BACKBONES = ComponentManager("backbones")
# Dataset manager
DATASETS = ComponentManager("datasets")
# Data augmentation (transforms) manager
TRANSFORMS = ComponentManager("transforms")
# Loss function manager
LOSSES = ComponentManager("losses")
# Optimizer manager
OPTIMIZERS = ComponentManager("optimizers")
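The managers above and `Builder.load_component_class` work together as a registry: decorating a class registers it under its `__name__` when the class definition runs, and the builder searches its manager list in order for the `type` string from the config. A self-contained sketch with a hypothetical, trimmed-down `Manager` class (not the real `ComponentManager`) and placeholder component classes:

```python
import warnings


class Manager:
    """Hypothetical, trimmed-down version of ComponentManager."""

    def __init__(self, name):
        self.name = name
        self.components_dict = {}

    def add_component(self, component):
        name = component.__name__
        if name in self.components_dict:
            warnings.warn(f"{name} exists already! It is now updated to {component} !!!")
        self.components_dict[name] = component
        # Returning the component unchanged is what lets this method act as a decorator
        return component

    def __getitem__(self, item):
        return self.components_dict[item]


MODELS = Manager('models')
BACKBONES = Manager('backbones')


@BACKBONES.add_component
class ResNet50_vd:
    pass


@MODELS.add_component
class DeepLabV3P:
    def __init__(self, backbone, num_classes):
        self.backbone = backbone
        self.num_classes = num_classes


def load_component_class(comp_list, class_type):
    # Same search as Builder.load_component_class: the first manager that knows the name wins
    for com in comp_list:
        if class_type in com.components_dict:
            return com[class_type]
    raise RuntimeError(f"The specified component ({class_type}) was not found.")


comp_list = [MODELS, BACKBONES]
print(load_component_class(comp_list, 'DeepLabV3P') is DeepLabV3P)    # True
print(load_component_class(comp_list, 'ResNet50_vd') is ResNet50_vd)  # True
try:
    load_component_class(comp_list, 'UNet')
except RuntimeError as e:
    print(e)  # The specified component (UNet) was not found.
```

Because registration is a side effect of executing the class definition, a component can only be found once the module that defines it has been imported.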
