剪枝与重参第七课:YOLOv8剪枝

这篇具有很好参考价值的文章主要介绍了剪枝与重参第七课:YOLOv8剪枝。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

YOLOv8剪枝

前言

手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。

本次课程主要讲解YOLOv8剪枝。

课程大纲可看下面的思维导图

剪枝与重参第七课:YOLOv8剪枝

1.Overview

YOLOV8剪枝的流程如下:

剪枝与重参第七课:YOLOv8剪枝

结论:在VOC2007上使用yolov8s模型进行的实验显示,预训练和约束训练在迭代50个epoch后达到了相同的mAP(:0.5)值,约为0.77。剪枝后,微调阶段需要65个epoch才能达到相同的mAP50。修建后的ONNX模型大小从43M减少到36M。

注意:我们需要将网络结构和网络权重区分开来,YOLOv8的网络结构来自yaml文件,如果我们进行剪枝后保存的权重文件的结构其实是和原始的yaml文件不符合的,需要对yaml文件进行修改满足我们的要求。

2.Pretrain(option)

步骤如下:

  • git clone https://github.com/ultralytics/ultralytics.git
  • use VOC2007, and modify the VOC.yaml(去除VOC2012的相关内容)
  • disable amp(禁用amp混合精度)
# FILE: ultralytics/yolo/engine/trainer.py
...
def check_amp(model):
    # Avoid using mixed precision to affect finetune
    return False # <============== modified(修改部分)
    device = next(model.parameters()).device  # get model device
    if device.type in ('cpu', 'mps'):
        return False  # AMP only used on CUDA devices

    def amp_allclose(m, im):
        # All close FP32 vs AMP results
    ...

3.Constrained Training

约束训练是为了筛选哪些channel比较重要,哪些channel没有那么重要,也就是我们上节课所说的稀疏训练

  • prune the BN layer by adding L1 regularizer.
# FILE: ultralytics/yolo/engine/trainer.py
...
# Backward
self.scaler.scale(self.loss).backward()

# <============ added(新增)
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
    self.optimizer_step()
    last_opt_step = ni
...

注意1:在剪枝时,我们选择加载last.pt而非best.pt,因为由于迁移学习,模型的泛化性比较好,在第一个epoch时mAP值最大,但这并不是真实的,我们需要稳定下来的一个模型进行prune

注意2:我们在对Conv层进行剪枝时,我们只考虑1v1(如BottleNeck,C2f and SPPF)、1vm(如Backbone,Detect)的情形,并不考虑mv1的情形。

思考:Constrained Training的必要性?

约束训练可以使得模型更易于剪枝。在约束训练中,模型会学习到一些通道或者权重系数比较不重要的信息,而这些信息在剪枝过程中得到应用,从而达到模型压缩的效果。而如果直接进行剪枝操作,可能会出现一些问题,比如剪枝后的模型精度大幅下降、剪枝不均匀等。因此,在进行剪枝操作之前,通过稀疏训练的方式,可以更好地准确地确定哪些通道或者权重系数可以被剪掉,从而避免上述问题的发生。

4.Prune

4.1 检查BN层的bias

  • 剪枝后,确保BN层的大部分bias足够小(接近于0),否则重新进行稀疏训练
for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

4.2 设置阈值和剪枝率

  • threshold:全局或局部
  • factor:保持率,裁剪太多不推荐
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)

4.3 最小剪枝Conv单元的TopConv

Top-Bottom Conv如下图所示:

剪枝与重参第七课:YOLOv8剪枝

TopConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta  = conv1.bn.bias.data.detach()

    keep_idxs = []    
    local_threshold = threshold

    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0] 
        local_threshold = local_threshold * 0.5

    n = len(keep_idxs)
    print(n / len(gamma) * 100)  # 打印我们保留了多少的channel
    
    # prune
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data   = beta[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.running_var.data  = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.num_features   = n
    conv1.conv.weight.data  = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n

    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

# pattern to prune
# 1. prune all 1 vs 1 TB pattern e.g. bottleneck
for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)

注意:由于NVIDIA的硬件加速的原因,我们保留的channels应该大于等于8,我们可以通过设置local_threshold,尽量小点,让更多的channel保留下来。

4.4 最小剪枝Conv单元的BottomConv

BottomConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):
    ...
    if not isinstance(conv2, list):
        conv2 = [conv2]
    
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

注意BottomConv存在两种情形,一种是单个Conv,还有一种是Conv列表。TopConv需要考虑conv2d+bn+act,而BottomConv只需要考虑conv2d

4.5 Seq剪枝

Seq剪枝的示例代码如下:

def prune(m1, m2):
    if isinstance(m1, C2f):
        m1 = m1.cv2
    
    if not isinstance(m2, list):
        m2 = [m2]
    
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    
    prune_conv(m1, m2)

# 2. prune sequential
seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i+1])

注意:我们不考虑1vm的情形,因此在4,6,9的module我们是不进行剪枝的,后续head进行Concat时是对4,6,9的module进行拼接的。考虑到前几个conv的特征提取的重要性,我们也不剪枝它们。(那感觉没剪几个module呀😂)

4.6 Detect-FPN剪枝

Detect-FPN剪枝的示例代码如下:

# 3. prune FPN related block
detect: Detect = seq[-1]

last_inputs = [seq[15], seq[18], seq[21]]
colasts     = [seq[16], seq[19], None]

for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

for name, p in yolo.model.named_parameters():
    p.requires_grad = True

注意:一定要设置所有参数为需要训练。因为加载后的model会给弄成False,导致报错

4.7 完整示例代码

完整的示例代码如下:

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect

# Load a model
yolo = YOLO("epoch-50-constrained_weights/last.pt")  # build a new model from scratch
model = yolo.model

ws = []
bs = []

for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
# keep
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)


def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta  = conv1.bn.bias.data.detach()
    # if gamma.abs().min() > f or beta.abs().min() > 0.1:
    #     return
    
    # idxs = torch.argsort(gamma.abs() * coeff + beta.abs(), descending=True)
    keep_idxs = []
    local_threshold = threshold
    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        local_threshold = local_threshold * 0.5
    n = len(keep_idxs)
    # n = max(int(len(idxs) * 0.8), p)
    print(n / len(gamma) * 100)
    # scale = len(idxs) / n
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data   = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n
    
    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

    if not isinstance(conv2, list):
        conv2 = [conv2]
        
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]
    
def prune(m1, m2):
    if isinstance(m1, C2f):      # C2f as a top conv
        m1 = m1.cv2
    
    if not isinstance(m2, list): # m2 is just one module
        m2 = [m2]
        
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    
    prune_conv(m1, m2)

for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)
        
seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i+1])
    
detect:Detect = seq[-1]
last_inputs   = [seq[15], seq[18], seq[21]]
colasts       = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

# ***step4,一定要设置所有参数为需要训练。因为加载后的model他会给弄成false。导致报错
# pipeline:
# 1. 为模型的BN增加L1约束,lambda用1e-2左右
# 2. 剪枝模型,比如用全局阈值
# 3. finetune,一定要注意,此时需要去掉L1约束。最终final的版本一定是去掉的
for name, p in yolo.model.named_parameters():
    p.requires_grad = True
    
# 1. 不能剪枝的layer,其实可以不用约束
# 2. 对于低于全局阈值的,可以删掉整个module
# 3. keep channels,对于保留的channels,他应该能整除n才是最合适的,否则硬件加速比较差
#    n怎么选,一般fp16时,n为8
#                int8时,n为16
#    cp.async.cg.shared
#

yolo.val()
# yolo.export(format="onnx")
# yolo.train(data="VOC.yaml", epochs=100)
print("done")

5.YOLOv8剪枝总结

关于yolov8剪枝有以下几点值得注意:

Pipeline:

    1. 为模型的BN增加L1约束,lambda用1e-2左右
    1. 剪枝模型使用的是全局阈值
    1. finetune模型时,一定要注意,此时需要去掉L1约束,最终的final的版本一定是去掉的(ultralytics/yolo/engine/trainer.py中注释)
    1. 对于yolo.model.named_parameters()循环,需要设置p.requires_gradTrue

Future work:

    1. 不能剪枝的layer,其实可以不用约束
    1. 对于低于全局阈值的,可以删掉整个module
    1. keep channels,对于保留的channels,它应该能整除n才是最合适的,否则硬件加速比较差
  • n怎么选呢?一般fp16时,n为8;int8时,n为16

总结

本次课程学习了YOLOv8的剪枝,主要是对前面剪枝课程的一个总结和实现吧,大体流程就是稀疏训练后进行剪枝最后微调,看着虽然简单,但实际细节把控还是非常多的,比如说哪些部分好剪,哪些部分不好剪,剪枝的过程中如何通过model获取想要prune的module等等,需要对YOLOv8整体网络结构和对ONNX模型的操作非常熟练,这还只是基础理论,实操部分的坑还没踩呢,在之后好好练习练习吧😄文章来源地址https://www.toymoban.com/news/detail-421417.html

到了这里,关于剪枝与重参第七课:YOLOv8剪枝的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 复习第七课 C语言-指针数组,函数,string

    目录 【1】指针和数组 【2】数组指针 【3】指针数组 【4】函数 【5】函数传参 【6】动态开辟堆区空间 【7】string函数族 【8】递归函数 练习: 直接访问:通过数组名访问 间接访问:通过指针访问 》1. 一维数组 运算方法: 1) *和++都是弹幕运算符,优先级相同 2) 单目运算

    2024年02月16日
    浏览(42)
  • 隐语第七课:多方安全分析语言SCQL架构详解

    目录 1、SCQL的定义 2、SCQL主要特点 3、SCQL的实现原理   4、SCQL对SQL的支持能力 5、列控制列表CCL 5.1 什么是CCL 5.2 CCL的设计缘由 5.3 CCL约束详解 6、SCQL的使用场景 7、下面分享Slides 首先必须感谢蚂蚁集团及隐语社区带来的学习机会!  第七课由操顺德老师讲述,核心内容是 详解

    2024年04月11日
    浏览(57)
  • STM32第七课:PWM控制SG90舵机

            学习完上一课的PWM控制LED小灯实现呼吸灯的效果,我们就可以进一步学习PWM控制舵机的效果了。PWM控制舵机相信会是一个更有意思的小实验的。          舵机是一种位置(角度)伺服的驱动器,适用于那些需要角度不断变化并可以保持的控制系统。目前在高档

    2024年04月17日
    浏览(32)
  • ACL 访问控制 过滤数据 维护网络安全(第七课)

    ACL是Access Control List(访问控制列表)的缩写,是一种用于控制文件、目录、网络设备等资源访问权限的方法。ACL可以对每个用户或用户组设置不同的访问权,即在访问控制清单中为每个用户或用户组指定允许或禁止访问该资源的权限。它通常由一系列规则组成,规则定义了一

    2024年02月10日
    浏览(48)
  • 吴恩达llama课程笔记:第七课llama安全工具

     羊驼Llama是当前最流行的开源大模型,其卓越的性能和广泛的应用领域使其成为业界瞩目的焦点。作为一款由Meta AI发布的开放且高效的大型基础语言模型,Llama拥有7B、13B和70B(700亿)三种版本,满足不同场景和需求。 吴恩达教授推出了全新的Llama课程,旨在帮助学习者全面

    2024年04月25日
    浏览(34)
  • 单片机第三季-第七课:STM32中断体系

    目录 1,NVIC 2,中断和事件的区别 3,优先级的概念  4,如何实际编程使用外部中断 5,STM32开发板通过按键控制LED  5.1,打开相应GPIO模块时钟 5.2,NVIC设置 5.3,外部中断线和配套的GPIO进行连接映射 5.4,代码文件  6,FSMC NVIC: Nested Vector Interrupt Control,嵌套向量中断控制器;

    2024年01月18日
    浏览(47)
  • React框架第七课 语法基础课《第一课React你好世界》

    从这一课开始真正进入到React框架的基础语法学习,之前的前五课做个了解即可。 ├── README.md 使用方法的文档 ├── node_modules 所有的依赖安装的目录 ├── package-lock.json 锁定安装时的包的版本号,保证团队的依赖能保证一致。 ├── package.json ├── public 静态公共目录

    2024年02月02日
    浏览(43)
  • C语言第七课----------函数的定义及使用--------C语言重要一笔

                                                      个人主页::小小页面                  gitee页面:秦大大                 一个爱分享的小博主 欢迎小可爱们前来借鉴 __________________________________________________________          1.函数是什么   

    2024年02月16日
    浏览(33)
  • YOLOv5剪枝✂️| 模型剪枝实战篇

    本篇博文所用代码为开源项目修改得到,且不适合基础太差的同学。 本篇文章主要讲解代码的使用方式,手把手带你实现YOLOv5模型剪枝操作。 0. 环境准备 终端键入:

    2024年02月05日
    浏览(55)
  • YOLOV7剪枝流程

    1)划分数据集进行训练前的准备,按正常的划分流程即可 2)修改train.py文件 第一次处在参数列表里添加剪枝的参数,正常训练时设置为False,剪枝后微调时设置为True 第二处位置在 下面修改代码,源代码为: 修改为: 第三处位置在 源代码为: 修改为 第四处位置在 源代码

    2024年01月22日
    浏览(34)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包