【MMDetection】——训练个人数据集

这篇具有很好参考价值的文章主要介绍了【MMDetection】——训练个人数据集。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1、数据集格式及存放

mmdet支持COCO格式和VOC格式,能用COCO格式,还是建议COCO的。网上有YOLO转COCO,VOC转COCO,可以自己转换。

在mmdetection代码的根目录下,创建 data/coco 文件夹,按照coco的格式排放好数据集。annotations下面是标签文件,train2017val2017test2017是图片。
【MMDetection】——训练个人数据集
【MMDetection】——训练个人数据集

2、修改两处

第一处: mmdet/core/evalution/class_names.py 代码下的 def coco_classes() 的 return 内容改为自己数据集的类别;

第二处:mmdet/datasets/coco.py 代码下的 class CocoDataset(CustomDataset) 的 CLASSES 改为自己数据集的类别;

注意:修改两处后,一定要在根目录下,输入命令:
python setup.py install build
重新编译代码,要不然类别会没有载入,还是原coco类别,训练异常。

3、用训练命令生成配置文件

python tools/train.py configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py --work-dir work_dirs

其中,work_dirs是自己在根目录新建的工作目录,训练文件存储在这里。

注意,此时运行命令之后,并不是直接训练就可以不管了!我们还有参数设置没改!这里输入训练命令,只是需要它生成一个配置文件,便于我们改参数!【MMDetection】——训练个人数据集

打开配置文件 cascade_rcnn_r50_fpn_1x_coco.py :
(1)修改 num_classes ,将其改为自己数据类别(直接全局搜索,有3处,都要改);

(2)修改 data_root 路径和训练集、验证集、测试集的图片和标签路径,如下图:
【MMDetection】——训练个人数据集

【MMDetection】——训练个人数据集
【MMDetection】——训练个人数据集

(3)修改训练图片大小和学习率

修改下处代码,可以更改图片大小

img_scale = (1333, 800), 

batch_size, mmdet默认的方式是由 GPU 数量与 samples_per_gpu 参数决定:
samples_per_gpu: 每个gpu读取的图像数量(意思不就是batch_size=2),该参数和训练时的gpu数量决定了训练时的batch_size。(为什么这么说呢?因为mmdet是8个GPU训练的,那么总的batch就是 8 *samples_per_gpu=16,即训练时是batch_size为16) 。
但我们通常是只有一个gpu, 该参数设置为 2, 意思就是我们训练的 batch_size为2;

workers_per_gpu: 读取数据时每个gpu分配的线程数 ,一般设置为 2即可;(我感觉既然用单个GPU,设置到8也无妨吧?我还没试)

【MMDetection】——训练个人数据集

学习率设置:
mmdet 默认的学习率是基于8个gpu,而且默认是1个GPU处理2个图像(就上面说的samples_per_gpu为2),可以这样理解:
8个GPU,每个GPU处理2张图片,那么真实训练总的一个batch就包括16张图片,学习率为0.02;
4个GPU,每个GPU处理2张图片,那么真实训练总的一个batch就包括8张图片,学习率为0.01;
1个GPU,每个GPU处理2张图片,那么真实训练总的一个batch就包括2张图片,学习率为0.0025;
1个GPU,每个GPU处理1张图片,那么真实训练总的一个batch就包括1张图片,学习率为0.00125;
【MMDetection】——训练个人数据集
(4)使用预训练模型
提前从github上下载预训练模型,新建一个checkpoints文件夹下,放到里面。(模型下载链接:https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md)
然后修改以下代码:

# 原本是 load_from = None ,修改为
load_from = 'checkpoints/fcascade_rcnn_r50_fpn_1x_coco_20200316-3dc56deb.pth’

(5)训练轮数,保存模型间隔,日志保存参数
【MMDetection】——训练个人数据集

4、正式训练开始

!!!看清楚路径!使用的是更改过的配置文件训练!!!

python tools/train.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py

5、报错记录

在第三步生成配置文件时,遇到以下报错:

AssertionError: The num_classes (10) in Shared2FCBBoxHead of
MMDataParallel does not matches the length of CLASSES 80) in
CocoDataset

即使在修改 coco.py 和 class_names.py 后运行 python setup.py install仍然无法解决;

解决方法:
根据报错信息,找到自己虚拟环境的/mmdet/datasets/coco.pymmdet/core/evaluation/class_names.py,再次修改
CocoDataset()coco_classes()l两处(跟第二步一样,其实打开,就能看到虚拟环境下的并没有修改成功)

参考链接:AssertionError: The num_classes (3) in Shared2FCBBoxHead of
MMDataParallel does not
matches

6、模型评价测试(VOC指标mAP、COCO指标AP)

(1)生成中间件

python tools/test.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_20.pth  --out results.pkl
  • work_dirs/cascade_rcnn_r50_fpn_1x_coco.py 模型配置文件(跟训练时的一样)
  • work_dirs/epoch_20.pth: 训练好的模型(我是训练了20epoch)
  • --out 指定 results.pkl 输出目录,可以自己指定输出目录

(2)使用COCO标准评估指标

python tools/analysis_tools/eval_metric.py  work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl  --eval=bbox

  • --eval,COCO数据集可选参数有:bbox 、segm、proposal ;对VOC数据集可选参数有:mAP

(3)使用VOC标准评估指标

# results.pkl 的顺序别放错,在中间。
python tools/voc_eval.py results.pkl work_dirs/cascade_rcnn_r50_fpn_1x_coco.py  
  • voc_eval.py 文件 mmdetection 2.X 版本删除了,可以去老版本1.X 找找

7、绘制每个类别bbox 的结果曲线图并保存

(1)使用 test.py 生成 results.bbox.json 文件(在根目录下,路径可自己指定)

python tools/test.py  work_dirs/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_20.pth  --format-only  --options "jsonfile_prefix=./results"

(2)获得COCO bbox错误结果每个类别,保存分析结果图像到目录results/

python tools/analysis_tools/coco_error_analysis.py results.bbox.json results  --ann=data/coco/annotations/instances_val2017.json
  • results.bbox.json:上一步生成的文件
  • results: 结果曲线图的生成目录, 此处将生成到results/ 目录下
  • –ann=data/coco/annotations/instances_val2017.json: 数据集标注文件存放路径

8、统计模型参数量和FLOPs

python tools/analysis_tools/get_flops.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py --shape 640 640
  • --shape 参数指定输入图片尺寸

9 计算混淆矩阵

python tools/analysis_tools/confusion_matrix.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl coco_confusion_matrix/
  • 需要三个参数,配置文件、pkl文件、输出目录

10 画PR曲线

plot_pr_curve.py 代码来自:https://blog.csdn.net/weixin_44966641/article/details/124558532

import os
import sys
import mmcv
import numpy as np
import argparse
import matplotlib.pyplot as plt

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from mmcv import Config
from mmdet.datasets import build_dataset


def plot_pr_curve(config_file, result_file, out_pic, metric="bbox"):
    """plot precison-recall curve based on testing results of pkl file.

        Args:
            config_file (list[list | tuple]): config file path.
            result_file (str): pkl file of testing results path.
            metric (str): Metrics to be evaluated. Options are
                'bbox', 'segm'.
    """
    
    cfg = Config.fromfile(config_file)
    # turn on test mode of dataset
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True

    # build dataset
    dataset = build_dataset(cfg.data.test)
    # load result file in pkl format
    pkl_results = mmcv.load(result_file)
    # convert pkl file (list[list | tuple | ndarray]) to json
    json_results, _ = dataset.format_results(pkl_results)
    # initialize COCO instance
    coco = COCO(annotation_file=cfg.data.test.ann_file)
    coco_gt = coco
    coco_dt = coco_gt.loadRes(json_results[metric]) 
    # initialize COCOeval instance
    coco_eval = COCOeval(coco_gt, coco_dt, metric)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    # extract eval data
    precisions = coco_eval.eval["precision"]
    '''
    precisions[T, R, K, A, M]
    T: iou thresholds [0.5 : 0.05 : 0.95], idx from 0 to 9
    R: recall thresholds [0 : 0.01 : 1], idx from 0 to 100
    K: category, idx from 0 to ...
    A: area range, (all, small, medium, large), idx from 0 to 3
    M: max dets, (1, 10, 100), idx from 0 to 2
    '''
    pr_array1 = precisions[0, :, 0, 0, 2] 
    pr_array2 = precisions[1, :, 0, 0, 2] 
    pr_array3 = precisions[2, :, 0, 0, 2] 
    pr_array4 = precisions[3, :, 0, 0, 2] 
    pr_array5 = precisions[4, :, 0, 0, 2] 
    pr_array6 = precisions[5, :, 0, 0, 2] 
    pr_array7 = precisions[6, :, 0, 0, 2] 
    pr_array8 = precisions[7, :, 0, 0, 2] 
    pr_array9 = precisions[8, :, 0, 0, 2] 
    pr_array10 = precisions[9, :, 0, 0, 2] 

    x = np.arange(0.0, 1.01, 0.01)
    # plot PR curve
    plt.plot(x, pr_array1, label="iou=0.5")
    plt.plot(x, pr_array2, label="iou=0.55")
    plt.plot(x, pr_array3, label="iou=0.6")
    plt.plot(x, pr_array4, label="iou=0.65")
    plt.plot(x, pr_array5, label="iou=0.7")
    plt.plot(x, pr_array6, label="iou=0.75")
    plt.plot(x, pr_array7, label="iou=0.8")
    plt.plot(x, pr_array8, label="iou=0.85")
    plt.plot(x, pr_array9, label="iou=0.9")
    plt.plot(x, pr_array10, label="iou=0.95")

    plt.xlabel("recall")
    plt.ylabel("precison")
    plt.xlim(0, 1.0)
    plt.ylim(0, 1.01)
    plt.grid(True)
    plt.legend(loc="lower left")
    plt.savefig(out_pic)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='config file path')
    parser.add_argument('pkl_result_file', help='pkl result file path')
    parser.add_argument('--out', default='pr_curve.png')
    parser.add_argument('--eval', default='bbox')
    cfg = parser.parse_args()

    plot_pr_curve(config_file=cfg.config, result_file=cfg.pkl_result_file, out_pic=cfg.out, metric=cfg.eval)


输入命令:

python plot_pr_curve.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl

11 查看完整config配置文件

python tools/misc/print_config.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py

12 核查数据增强的结果是否正确

python tools/misc/browse_dataset.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py  --output-dir work_dirs/

8、参考链接

https://blog.csdn.net/qq_35077107/article/details/124768460?spm=1001.2014.3001.5502

https://blog.csdn.net/weixin_44966641/article/details/124558532文章来源地址https://www.toymoban.com/news/detail-463234.html

到了这里,关于【MMDetection】——训练个人数据集的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • mmdetection训练自己的COCO数据集及常见问题

    训练自己的VOC数据集及常见问题见下文: mmdetection训练自己的VOC数据集及常见问题_不瘦8斤的妥球球饼的博客-CSDN博客_mmdetection训练voc 目录 一、环境安装 二、训练测试步骤 三、常见问题 batch size设置 学习率和epoch的修改 训练过程loss为nan的问题 GPU out of memory 保存最佳权重文件

    2024年02月06日
    浏览(65)
  • MMdetection 环境配置、config文件解析以及训练自定义VOC数据集

    MMDetection是针对目标检测任务推出的一个开源项目,它基于Pytorch实现了大量的目标检测算法,把数据集构建、模型搭建、训练策略等过程都封装成了一个个模块,通过模块调用的方式,我们能够以很少的代码量实现一个新算法,大大提高了代码复用率。本文记录一下关于MMd

    2024年02月14日
    浏览(36)
  • MMDetection3d对KITT数据集的训练与评估介绍

    如有错误,恳请指出。 在之后的时间内,可能会学习与点云相关的知识,进一步学习基于点云的3D目标检测。然后,为了快速入门这个领域,想使用mmdetection3d开源算法库来尝试训练一些经典的3d目标检测模型,比如:SECOND,PointPillars,3D-SSD等等。之后重点是详细介绍KITTI数据

    2024年02月02日
    浏览(41)
  • 【MMDetection3D】环境搭建,使用PointPillers训练&测试&可视化KITTI数据集

    2D卷不动了,来卷3D,之后更多地工作会放到3D检测上 本文将简单介绍什么是3D目标检测、KITTI数据集以及MMDetection3D算法库,重点介绍如何在MMDetection3D中,使用PointPillars算法训练KITTI数据集,并对结果进行测试和可视化。   对于一张输入图像,2D目标检测旨在给出物体类别并标

    2024年02月03日
    浏览(51)
  • mmdetection3d-之(三)--FCOS3d训练waymo数据集

    本内容分为两部分 1. waymo数据集转KITTI格式 2. FCOS3D训练KITTI格式的waymo数据集 1.1.1 waymo数据集下载 waymo数据集v1.2.0可以从这里下载。其中,train(32个压缩包),test(8个压缩包),val(8个压缩包)。这里的文件都是压缩包,每个都有20个G左右。 如果不想下载压缩包,可以下载

    2024年01月16日
    浏览(49)
  • mmdetection3d-之(一)--FCOS3d训练nuscenes-mini数据集

    参考网上的博客,出现各种错误,最大的是: AssertionError: Samples in split doesn\\\'t match samples in predictions. 给了解决方案,也不知道那个数字是怎么来的。索性自己来一遍,参考了github issue。   第一步,下载数据集并解压: 第二步,修改代码 tools/create_data.py   第三步,制作数据

    2024年02月15日
    浏览(49)
  • 【mmdetection】用自己的coco数据集训练mask r-cnn并进行验证、测试,推理可视化,更改backbone,只针对某一标签进行训练

    本人呕心沥血从无到有的摸索,自己边尝试边整理的,其实耐心多看官方文档确实能找到很多东西(下面有官方文档的链接这里就不重复粘贴了),也为了方便我自己copy语句嘻嘻~ 为什么不是用Windows,作为一个小白我一开始真的想用windows,因为我懒得配双系统,但是没办法

    2024年02月04日
    浏览(45)
  • 【 [mmdetection] 如何在训练中断后,接着上次训练?】

    最近由于不知名原因,在用 faster rcnn 训练一个大型数据集的时候,在 epoch= 20 的时候中断训练了.采用以下方式继续上次训练. 打开 train.py ,如图: 也就是说,训练时,最后加一个– resume from 参数,然后后面跟上次训练生成的最后一个权重文件( .pth )就可以了. 因此,在命令行输入以下语

    2024年02月09日
    浏览(38)
  • 【nnunet】个人数据训练心得

    GitHub代码: GitHub - MIC-DKFZ/nnUNet 十项医学分割数据集: Medical Segmentation Decathlon 注意:安装时一定不能使用魔法,否则会被伏地魔(False) 这里有几个铁汁,可以一起参考,以他们的为主,我的为辅,一起食用 (四:2020.07.28)nnUNet最舒服的训练教程(让我的奶奶也会用nnUNet(

    2024年02月16日
    浏览(40)
  • deeplabcut 简明教程(训练个人数据集)

    从github下载deeplabcut 然后cd 到该目录下 激活ipython  创建新项目 设置配置文件  根据设定的配置文件从视频中截取帧 标定数据集 检查数据集标定效果 创建训练数据集(这一步要在你训练的地方执行 本地/云端)  查看训练和设定训练参数 训练模型  模型评估(评估上一步所

    2024年02月05日
    浏览(36)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包