yolov5使用知识蒸馏

这篇具有很好参考价值的文章主要介绍了yolov5使用知识蒸馏。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

提示:本文采用的蒸馏方式为 Distilling Object Detectors with Fine-grained Feature Imitation 这篇文章


前言

提示:这里可以添加本文要记录的大概内容:

本文介绍的论文《Distilling Object Detectors with Fine-grained Feature Imitation》即是基于 Fine-grained Feature Imitation 技术的目标检测知识蒸馏方法。该方法将 Fine-grained Feature Imitation 应用于学生模型的中间层,以捕捉更丰富的特征信息。通过在训练过程中引入目标检测任务的监督信号,Fine-grained Feature Imitation 技术可以更好地保留复杂模型中的细节特征,从而提高了轻量级模型的性能。


提示:以下是本篇文章正文内容,下面案例可供参考

一、Distilling Object Detectors with Fine-grained Feature Imitation 论文介绍

示例:pandas 是基于NumPy 的一种工具,该工具是为了解决数据分析任务而创建的。

1.创新点

Fine-grained Feature Imitation 技术可以概括为以下三个步骤:

  1. 利用复杂模型的中间层作为特征提取器,并用它提取学生模型的中间层的特征。

  2. 利用 Fine-grained Feature Imitation 技术对特征进行蒸馏,使学生模型能够学习到更丰富的特征信息。

  3. 在训练过程中引入目标检测任务的监督信号,以更好地保留复杂模型中的细节特征。

其核心思想是 teacher 网络中需要传递给 student 网络的应该是有效信息,而非无效的 background 信息。

2.内容介绍

1. Fine-Gained区域提取

yolov5使用知识蒸馏
上图中的红色和绿色边界框是在相应位置上的锚框。红色 anchor 表示与 gt 的边界框重叠最大,绿色 anchor 表示附近的物体样本。蒸馏时并不是对所有的anchor蒸馏,而是对gt框附近的anchor进行蒸馏,对于backbone输出的特征图,假设尺度为H X W,
网络中使用的anchor数量为K, 具体执行步骤如下:

  1. 对于给定的特征图,生成H X W X K 个anchor, 并计算与gt anchor的IOU值m,
  2. 计算最大的IOU值 M = max(m), 引入参数阈值因子Ψ, 计算过滤阈值F = M x Ψ,
    利用F进行IOU过滤,这里只保留大于F的部分,计算之后得到一个mask, 尺度为H X W.

2. loss 损失值

yolov5使用知识蒸馏
损失函数部分由两块组成,一块为Fine-grained Feature Imitation 损失,另一块为目标检测的分类和回归损失,
yolov5使用知识蒸馏
yolov5使用知识蒸馏

论文中展示了实验的对比结果,原论文是基于Faster Rcnn算法进行蒸馏,因此本文选择基于yolov5算法进行蒸馏。

二、yolov5 添加知识蒸馏

1.部分代码展示

调整gt anchors转换为相对于原图的位置

def make_gt_boxes(gt_boxes, max_num_box, batch, img_size):
    new_gt_boxes = []
    for i in range(batch):
        # 获取第i个batch的所有真实框
        boxes = gt_boxes[gt_boxes[:, 0] == i]
        # 真实框的个数
        num_boxes = boxes.size(0)
        if num_boxes < max_num_box:
            gt_boxes_padding = torch.zeros([max_num_box, gt_boxes.size(1)], dtype=torch.float)
            gt_boxes_padding[:num_boxes, :] = boxes
        else:
            gt_boxes_padding = boxes[:max_num_box]
        new_gt_boxes.append(gt_boxes_padding.unsqueeze(0))
    new_gt_boxes = torch.cat(new_gt_boxes)
    # transfer [x, y, w, h] to [x1, y1, x2, y2]
    new_gt_boxes_aim = torch.zeros(size=new_gt_boxes.size())
    new_gt_boxes_aim[:, :, 2] = (new_gt_boxes[:, :, 2] - 0.5 * new_gt_boxes[:, :, 4]) * img_size[1]
    new_gt_boxes_aim[:, :, 3] = (new_gt_boxes[:, :, 3] - 0.5 * new_gt_boxes[:, :, 5]) * img_size[0]
    new_gt_boxes_aim[:, :, 4] = (new_gt_boxes[:, :, 2] + 0.5 * new_gt_boxes[:, :, 4]) * img_size[1]
    new_gt_boxes_aim[:, :, 5] = (new_gt_boxes[:, :, 3] + 0.5 * new_gt_boxes[:, :, 5]) * img_size[0]
    return new_gt_boxes_aim

计算掩码 mask

def getMask(batch_size, gt_boxes, img_size, feat, anchors, max_num_box, device):
    # [b, K, 4]
    gt_boxes = make_gt_boxes(gt_boxes, max_num_box, batch_size, img_size)
    # 原图相对于当前特征图的步长
    feat_stride = img_size[0] / feat.size(2)
    anchors = torch.from_numpy(generate_anchors(feat_stride, anchors))
    feat = feat.cpu()
    height, width = feat.size(2), feat.size(3)
    feat_height, feat_width = feat.size(2), feat.size(3)
    shift_x = np.arange(0, feat_width) * feat_stride
    shift_y = np.arange(0, feat_height) * feat_stride
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
    shifts = torch.from_numpy(np.vstack((shift_x.ravel(), shift_y.ravel(),
                                         shift_x.ravel(), shift_y.ravel())).transpose())
    shifts = shifts.contiguous().type_as(feat).float()

    # num of anchors [3]
    A = anchors.size(0)
    K = shifts.size(0)

    anchors = anchors.type_as(gt_boxes)
    # all_anchors [K, A, 4]
    all_anchors = anchors.view(1, A, 4) + shifts.view(K, 1, 4)
    all_anchors = all_anchors.view(K * A, 4)
    # compute iou [all_anchors, gt_boxes]
    IOU_map = bbox_overlaps_batch(all_anchors, gt_boxes, img_size).view(batch_size, height, width, A, gt_boxes.shape[1])

    mask_batch = []
    for i in range(batch_size):
        max_iou, _ = torch.max(IOU_map[i].view(height * width * A, gt_boxes.shape[1]), dim=0)
        mask_per_im = torch.zeros([height, width], dtype=torch.int64).to(device)
        for k in range(gt_boxes.shape[1]):
            if torch.sum(gt_boxes[i][k]) == 0:
                break
            max_iou_per_gt = max_iou[k] * 0.5
            mask_per_gt = torch.sum(IOU_map[i][:, :, :, k] > max_iou_per_gt, dim=2)
            mask_per_im += mask_per_gt.to(device)
        mask_batch.append(mask_per_im)
    return mask_batch

计算imitation损失

def compute_mask_loss(mask_batch, student_feature, teacher_feature, imitation_loss_weight):
    mask_list = []
    for mask in mask_batch:
        mask = (mask > 0).float().unsqueeze(0)
        mask_list.append(mask)
    # [batch, height, widt
    mask_batch = torch.stack(mask_list, dim=0)
    norms = mask_batch.sum() * 2
    mask_batch_s = mask_batch.unsqueeze(4)
    no = student_feature.size(-1)
    bs, na, height, width, _ = mask_batch_s.shape
    mask_batch_no = mask_batch_s.expand((bs, na, height, width, no))
    sup_loss = (torch.pow(teacher_feature - student_feature, 2) * mask_batch_no).sum() / norms
    sup_loss = sup_loss * imitation_loss_weight
    return sup_loss

总结

完整代码请查看GitHub,麻烦动动小手点亮一下star
https://github.com/xing-bing文章来源地址https://www.toymoban.com/news/detail-401165.html

到了这里,关于yolov5使用知识蒸馏的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • YOLOv5训练数据提示No labels found、with_suffix使用、yolov5训练时出现WARNING: Ignoring corrupted image and/or label

    仔细看下数据加载、处理的文件datasets.py,发现有一句会根据第2步中images文件夹的位置找到对应labels文件夹: YOLOv5加载标签的地方在 datasets.py 中的这个地方,我们修改一下加载label的路径为自己的label放置位置就好。 在这个 img2label_paths 函数中,我们的修改如下:【因为我们

    2024年02月04日
    浏览(54)
  • YOLOv5、v7改进之二十八:ICLR 2022涨点神器——即插即用的动态卷积ODConv

    前 言: 作为当前先进的深度学习目标检测算法YOLOv5、v7系列算法,已经集合了大量的trick,但是在处理一些复杂背景问题的时候,还是容易出现错漏检的问题。此后的系列文章,将重点对YOLO系列算法的如何改进进行详细的介绍,目的是为了给那些搞科研的同学需要创新点或者

    2024年02月05日
    浏览(35)
  • 目标检测算法——YOLOv5/v7/v8改进结合即插即用的动态卷积ODConv(小目标涨点神器)

    作者将CondConv中一个维度上的动态特性进行了扩展,同时了考虑了空域、输入通道、输出通道等维度上的动态性,故称之为 全维度动态卷积 。ODConv通过并行策略采用多维注意力机制沿核空间的四个维度学习互补性注意力。 作为一种“即插即用”的操作,它可以轻易的嵌入到

    2024年01月19日
    浏览(50)
  • YOLOv5基础知识入门(2)— YOLOv5核心基础知识讲解

    前言: Hello大家好,我是小哥谈。 YOLOV4出现之后不久,YOLOv5横空出世。YOLOv5在YOLOv4算法的基础上做了进一步的改进,使检测性能得到更进一步的提升。YOLOv5算法作为目前工业界使用的最普遍的检测算法,存在着很多可以学习的地方。本文将对YOLOv5检测算法的核心基础知识进行

    2024年02月14日
    浏览(52)
  • 知识蒸馏实战:使用CoatNet蒸馏ResNet

    知识蒸馏(Knowledge Distillation),简称KD,将已经训练好的模型包含的知识(”Knowledge”),蒸馏(“Distill”)提取到另一个模型里面去。Hinton在\\\"Distilling the Knowledge in a Neural Network\\\"首次提出了知识蒸馏(暗知识提取)的概念,通过引入与教师网络(Teacher network:复杂、但预测精度优

    2024年02月06日
    浏览(50)
  • Yolov5一些知识

    Yolov5官方代码中,给出的目标检测网络中一共有4个版本,分别是Yolov5s、Yolov5m、Yolov5l、Yolov5x四个模型。 eg:Yolov5s Yolov3的网络结构是比较经典的 one-stage 结构,分为 输入端、Backbone、Neck和Prediction 四个部分 Yolov4在Yolov3的基础上进行了很多的创新。 比如: 输入端 ,主要包括 Mo

    2024年02月12日
    浏览(33)
  • 如何使YOLOv5在检测到目标后进行声音告警提示?

    导师有一个异常行为检测的小任务(吸烟行为检测),给我让我和师弟一起去完成。本身以为在YOLOv5的detect.py检测脚本中加入语音提示很简单,但是其中的过程却是一言难尽。 这也是查阅了很多资料,尝试过了各种大佬分享的经验,集百家之长完成了这个任务,感谢CSDN中各

    2024年01月19日
    浏览(44)
  • YOLOv5基础知识点——性能指标

    目标检测(object detection)=what + where Localization+Recongnition 类别标签(category label) 置信度得分(confidence score) 最小外接矩形(bounding box) 定位是找到检测图像中带有一个给定标签的单个目标; 检测是寻找到图像中带有给定标签的所有目标 目标检测性能指标= 检测精度+检测速

    2024年02月05日
    浏览(39)
  • YOLOv5基础知识点——激活函数

    ​​​​​​​什么是激活函数该选哪种激活函数?_哔哩哔哩_bilibili 深度学习笔记:如何理解激活函数?(附常用激活函数) - 知乎 (zhihu.com)  详解激活函数(Sigmoid/Tanh/ReLU/Leaky ReLu等) - 知乎 (zhihu.com) 算法面试问题二(激活函数相关)【这些面试题你都会吗】 - 知乎 (zhi

    2024年02月09日
    浏览(52)
  • YOLOv5基础知识入门(3)— 目标检测相关知识点

      前言 : Hello大家好,我是小哥谈。 YOLO算法发展历程和YOLOv5核心基础知识学习完成之后,接下来我们就需要学习目标检测相关知识了。为了让大家后面可以顺利地用YOLOv5进行目标检测实战,本节课就带领大家学习一下目标检测的基础知识点,希望大家学习之后有所收获!

    2024年02月13日
    浏览(39)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包