【mmaction2 入门教程 04】训练 AVA 数据集中的自定义类别

这篇具有很好参考价值的文章主要介绍了【mmaction2 入门教程 04】训练 AVA 数据集中的自定义类别。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

0 前言

在时空行为数据集中,最常出现的就是长尾数据集,即某些类别的动作标签过少,导致训练效果不好,在mmation2当中,提供了一个方法,就是可以自定义要训练的类别。

那么先看看我之前训练的分析结果
【mmaction2 入门教程 04】训练 AVA 数据集中的自定义类别
上图中,深蓝色的样本数量,浅蓝色的ap值,可以看出,样本极少的类别,ap值几乎为0,那么我们在训练的时候,可以忽略这些行为。

那如何来做呢,在mmaction2文档中已经说明了:https://mmaction2.readthedocs.io/zh_CN/latest/detection_models.html#id11

GPU平台:https://cloud.videojj.com/auth/register?inviter=18452&activityChannel=student_invite

b站:https://www.bilibili.com/video/BV1RV4y1s7mK/

1(官网)训练 AVA 数据集中的自定义类别

用户可以训练 AVA 数据集中的自定义类别。AVA 中不同类别的样本量很不平衡:其中有超过 100000 样本的类别: stand/listen to (a person)/talk to (e.g., self, a person, a group)/watch (a person),也有样本较少的类别(半数类别不足 500 样本)。大多数情况下,仅使用样本较少的类别进行训练将在这些类别上得到更好精度。

训练 AVA 数据集中的自定义类别包含 3 个步骤:

  1. 从原先的类别中选择希望训练的类别,将其填写至配置文件的 custom_classes 域中。其中 0 不表示具体的动作类别,不应被选择。
  2. num_classes 设置为 num_classes = len(custom_classes) + 1
    • 在新的类别到编号的对应中,编号 0 仍对应原类别 0,编号 i (i > 0) 对应原类别 custom_classes[i-1]
    • 配置文件中 3 处涉及 num_classes 需要修改:model -> roi_head -> bbox_head -> num_classes, data -> train -> num_classes, data -> val -> num_classes.
    • num_classes <= 5, 配置文件 BBoxHeadAVA 中的 topk 参数应被修改。topk 的默认值为 (3, 5)topk 中的所有元素应小于 num_classes
  3. 确认所有自定义类别在 label_file 中。

2 训练数据集中的自定义类别

那么就开始咯,通过统计分析,原本71个动作类,选取了其中21个动作类。

2.1 配置文件

直接上配置文件:

# model setting
custom_classes = [3, 8, 11, 12, 13, 22, 23, 29, 32, 33, 36, 42, 45, 46, 49, 52, 55, 56, 59, 60, 67]
#custom_classes = [2, 7, 10, 11, 12, 21, 22, 28, 31, 32, 35, 41, 44, 45, 48, 51, 54, 55, 58, 59, 66]
num_classes = len(custom_classes) + 1
model = dict(
    type='FastRCNN',
    backbone=dict(
        type='ResNet3dSlowFast',
        pretrained=None,
        resample_rate=8,
        speed_ratio=8,
        channel_ratio=8,
        slow_pathway=dict(
            type='resnet3d',
            depth=50,
            pretrained=None,
            lateral=True,
            conv1_kernel=(1, 7, 7),
            dilations=(1, 1, 1, 1),
            conv1_stride_t=1,
            pool1_stride_t=1,
            inflate=(0, 0, 1, 1),
            spatial_strides=(1, 2, 2, 1)),
        fast_pathway=dict(
            type='resnet3d',
            depth=50,
            pretrained=None,
            lateral=False,
            base_channels=8,
            conv1_kernel=(5, 7, 7),
            conv1_stride_t=1,
            pool1_stride_t=1,
            spatial_strides=(1, 2, 2, 1))),
    roi_head=dict(
        type='AVARoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor3D',
            roi_layer_type='RoIAlign',
            output_size=8,
            with_temporal_pool=True),
        bbox_head=dict(
            type='BBoxHeadAVA',
            in_channels=2304,
            #num_classes=81,
            num_classes=num_classes,
            multilabel=True,
            dropout_ratio=0.5)),
    train_cfg=dict(
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssignerAVA',
                pos_iou_thr=0.9,
                neg_iou_thr=0.9,
                min_pos_iou=0.9),
            sampler=dict(
                type='RandomSampler',
                num=32,
                pos_fraction=1,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=1.0,
            debug=False)),
    test_cfg=dict(rcnn=dict(action_thr=0.002)))

dataset_type = 'AVADataset'
data_root = '/home/MPCLST/Dataset/rawframes'
anno_root = '/home/MPCLST/Dataset/annotations'


#ann_file_train = f'{anno_root}/ava_train_v2.1.csv'
ann_file_train = f'{anno_root}/train.csv'
#ann_file_val = f'{anno_root}/ava_val_v2.1.csv'
ann_file_val = f'{anno_root}/val.csv'

#exclude_file_train = f'{anno_root}/ava_train_excluded_timestamps_v2.1.csv'
#exclude_file_val = f'{anno_root}/ava_val_excluded_timestamps_v2.1.csv'

exclude_file_train = f'{anno_root}/train_excluded_timestamps.csv'
exclude_file_val = f'{anno_root}/val_excluded_timestamps.csv'

#label_file = f'{anno_root}/ava_action_list_v2.1_for_activitynet_2018.pbtxt'
label_file = f'{anno_root}/action_list.pbtxt'

proposal_file_train = (f'{anno_root}/dense_proposals_train.pkl')
proposal_file_val = f'{anno_root}/dense_proposals_val.pkl'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)

train_pipeline = [
    dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
    dict(type='RawFrameDecode'),
    dict(type='RandomRescale', scale_range=(256, 320)),
    dict(type='RandomCrop', size=256),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW', collapse=True),
    # Rename is needed to use mmdet detectors
    dict(type='Rename', mapping=dict(imgs='img')),
    dict(type='ToTensor', keys=['img', 'proposals', 'gt_bboxes', 'gt_labels']),
    dict(
        type='ToDataContainer',
        fields=[
            dict(key=['proposals', 'gt_bboxes', 'gt_labels'], stack=False)
        ]),
    dict(
        type='Collect',
        keys=['img', 'proposals', 'gt_bboxes', 'gt_labels'],
        meta_keys=['scores', 'entity_ids'])
]
# The testing is w/o. any cropping / flipping
val_pipeline = [
    dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW', collapse=True),
    # Rename is needed to use mmdet detectors
    dict(type='Rename', mapping=dict(imgs='img')),
    dict(type='ToTensor', keys=['img', 'proposals']),
    dict(type='ToDataContainer', fields=[dict(key='proposals', stack=False)]),
    dict(
        type='Collect',
        keys=['img', 'proposals'],
        meta_keys=['scores', 'img_shape'],
        nested=True)
]

data = dict(
    #videos_per_gpu=9,
    #workers_per_gpu=2,
    videos_per_gpu=5,
    workers_per_gpu=2,
    val_dataloader=dict(videos_per_gpu=1),
    test_dataloader=dict(videos_per_gpu=1),
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        exclude_file=exclude_file_train,
        pipeline=train_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_train,
        person_det_score_thr=0.9,
        data_prefix=data_root,
        num_classes=num_classes,
        custom_classes=custom_classes,
        start_index=1,),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        exclude_file=exclude_file_val,
        pipeline=val_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_val,
        person_det_score_thr=0.9,
        data_prefix=data_root,
        num_classes=num_classes,
        custom_classes=custom_classes,
        start_index=1,))
data['test'] = data['val']

#optimizer = dict(type='SGD', lr=0.1125, momentum=0.9, weight_decay=0.00001)
optimizer = dict(type='SGD', lr=0.0125, momentum=0.9, weight_decay=0.00001)
# this lr is used for 8 gpus

optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy

lr_config = dict(
    policy='step',
    step=[10, 15],
    warmup='linear',
    warmup_by_epoch=True,
    warmup_iters=5,
    warmup_ratio=0.1)
#total_epochs = 20
total_epochs = 35
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, save_best='mAP@0.5IOU')
log_config = dict(
    interval=20, hooks=[
        dict(type='TextLoggerHook'),
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = ('./work_dirs/22-8-15-custom-ava/'
            'slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb')
load_from = ('https://download.openmmlab.com/mmaction/recognition/slowfast/'
             'slowfast_r50_4x16x1_256e_kinetics400_rgb/'
             'slowfast_r50_4x16x1_256e_kinetics400_rgb_20200704-bcde7ed7.pth')
resume_from = None
find_unused_parameters = False

【mmaction2 入门教程 04】训练 AVA 数据集中的自定义类别
【mmaction2 入门教程 04】训练 AVA 数据集中的自定义类别
【mmaction2 入门教程 04】训练 AVA 数据集中的自定义类别

2.2 执行训练

#训练
cd /home/MPCLST/mmaction2_YF
python tools/train.py configs/detection/ava/my_custom_slowfast_kinetics_pretrained_r50_4x16x1_20e_ava.py --validate

2.3 可视化测试

在上一篇博客写过:【mmaction2 入门教程 03】评价指标可视化 mAP、每类行为的ap值、每类行为的数量

# 测试集的分析-latest
cd /home/MPCLST/mmaction2_YF/   
python tools/test.py configs/detection/ava/my_custom_slowfast_kinetics_pretrained_r50_4x16x1_20e_ava.py ./work_dirs/22-8-15-custom-ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/latest.pth --eval mAP

【mmaction2 入门教程 04】训练 AVA 数据集中的自定义类别

# 训练集的分析-latest
cd /home/MPCLST/mmaction2_YF/   
python tools/test.py configs/detection/ava/my_custom_slowfast_kinetics_pretrained_r50_4x16x1_20e_ava2.py ./work_dirs/22-8-15-custom-ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/latest.pth --eval mAP

【mmaction2 入门教程 04】训练 AVA 数据集中的自定义类别文章来源地址https://www.toymoban.com/news/detail-432508.html

到了这里,关于【mmaction2 入门教程 04】训练 AVA 数据集中的自定义类别的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • HC-SR04超级简单教程(快速入门)

    目录 一、模块介绍(个人理解)         1.简单理解         2.该模块的参数 二、HC-SR04的操作 三、代码         1.代码前的注意事项         2.关键代码 四、代码实战效果图  五、结束         HC-SR04是一个超声波测距模块,通过发出超声波然后接收超声波

    2024年02月16日
    浏览(50)
  • 551、Elasticsearch详细入门教程系列 -【分布式全文搜索引擎 Elasticsearch(二)】 2023.04.04

    1.1 Elasticsearch中的数据格式 Elasticsearch 是面向文档型数据库,一条数据在这里就是一个文档。为了方便大家理解,我们将 Elasticsearch 里存储文档数据和关系型数据库 MySQL 存储数据的概念进行一个类比。 ES 里的 Index 可以看做一个库,而 Types 相当于表,Documents 则相当于表的行。

    2023年04月11日
    浏览(84)
  • 【AI绘画】《超入门级教程:训练自己的LORA模型》,MM超爱的萌宠图片实战

    SD-Trainer:是stable diffusion进行lora训练的webui,有了SD-Trainer,只需要少许图片,每个人都能够方便快捷地训练出属于自 己的stable diffusion模型,可以让图片按照你的想法进行呈现。 SD-Trainer :是stable diffusion进行lora训练的webui,有了SD-Trainer,只需要少许图片,每个人都能够方便快

    2024年02月14日
    浏览(46)
  • 【yolov5 安装教程】(入门篇)避免踩雷保姆级教程 在m1芯片下 使用yolov5本地训练自己的数据集 ——mac m1

    ​​​​​​​ 目录 一、简介 配置 环境准备 二、环境配置 1.安装anaconda 2.安装TensorFlow 3.安装pytorch 4.pyqt5安装  5.安装labelimg 6.下载yolov5 7.pycharm安装 三、使用labelimg标记图片 1.准备工作 2.标记图片 四、 划分数据集以及配置文件修改 1. 划分训练集、验证集、测试集 2.XML格式转

    2024年02月05日
    浏览(53)
  • Django框架入门到精通(04)Django创建第一个项目 (黄菊华老师大学生毕业设计学习教程)

    博主介绍: 《Vue.js入门与商城开发实战》《微信小程序商城开发》图书作者,CSDN博客专家,在线教育专家,CSDN钻石讲师;专注大学生毕业设计教育和辅导。 所有项目都配有从入门到精通的基础知识视频课程,免费 项目配有对应开发文档、开题报告、任务书、PPT、论文模版

    2024年02月06日
    浏览(53)
  • win11系统AVA2.1数据集制作、训练、测试、本地视频验证(完整已跑通)

    本文参照杨帆老师的博客,根据自己的需要进行制作,杨帆老师博客原文链接如下: 自定义ava数据集及训练与测试 完整版 时空动作/行为 视频数据集制作 yolov5, deep sort, VIA MMAction, SlowFast-CSDN博客 文章浏览阅读2.2w次,点赞31次,收藏165次。前言这一篇博客应该是我花时间最多

    2024年02月19日
    浏览(54)
  • Yalmip入门教程(1)-入门学习

            博客中所有内容均来源于自己学习过程中积累的经验以及对yalmip官方文档的翻译:YALMIP         Yalmip的作者是Johan Löfberg,是由Matlab平台编程实现的一个免费开源数学优化工具箱,在官网上就可以下载。官方下载链接如下: Download - YALMIP         下载时可以选

    2024年02月15日
    浏览(51)
  • 瑞萨MCU入门教程(非常详细的瑞萨单片机入门教程)

    得益于瑞萨强大的MCU、强大的软件开发工具(e² studio),也得益于瑞萨和RA生态工作室提供的支持,我们团队编写了《ARM嵌入式系统中面向对象的模块编程方法》,全书37章,将近500页: 讲解面向对象编程在单片机开发中的使用 结合FSP软件包实例分析外设驱动 讲解如何使用RASC配

    2024年02月08日
    浏览(47)
  • MaterialDesignInXAML WPF入门教程 快速入门

    先去MaterialDesignInXAML下载下来源码,以及Releases,在DemoApp 中就可以看到实际的效果很惊艳了。 除了要有一定的C#、winform 基础外,建议先学习一下 XAML,对整个开发环境有个基础的了解,再来学习此教程。 可以去bilibili上免费学习一下。教程一共12个小时,如果不看后面的实战

    2024年02月05日
    浏览(56)
  • 爬虫教程1_Xpath 入门教程

    在编写爬虫程序的过程中提取信息是非常重要的环节,但是有时使用正则表达式无法匹配到想要的信息,或者书写起来非常麻烦,此时就需要用另外一种数据解析方法,也就是本节要介绍的 Xpath 表达式。 XPath(全称:XML Path Language)即 XML 路径语言,它是一门在 XML 文档中查找

    2024年02月14日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包