模型优化之模型剪枝

这篇具有很好参考价值的文章主要介绍了模型优化之模型剪枝。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一、概述
模型剪枝按照结构划分,主要包括结构化剪枝和非结构化剪枝:
(1)结构化剪枝:剪掉神经元节点之间的不重要的连接。相当于把权重矩阵中的单个权重值设置为0。
模型剪枝,模型优化,剪枝,深度学习,人工智能
(2)非结构化剪枝:把权重矩阵中某个神经元节点去掉,则和神经元相连接的突触也要全部去除。相当于同时去除权重矩阵中的某一行和列。如何判断神经元节点的重要程度呢?可以通过计算神经元对应的行和列的权重值的平方和的根的大小进行排序,把排序在后面一定比例的神经元节点去掉
模型剪枝,模型优化,剪枝,深度学习,人工智能
二、pytorch中模型剪枝:
Pytorch中模型的剪枝方法有三种,局部剪枝、全局剪枝和自定义剪枝。与剪枝有关的接口封装在torch.nn.utils.prune中。接下来开始演示三种剪枝在LeNet网络中的应用效果,我们首先给出LeNet网络结构。

import torch
from torch import nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) 
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

(1)局部剪枝
在本人理解就是一层一层的单独剪枝,下面代码还附有多参数多网络结构剪枝:

def part_cut(model):
    '''
    ######################################局部剪枝#########################################
    剪枝之后会产生一个mask
    剪枝api:prune.random_unstructured(layer1, name="weight", amount=0.3)
            amount:剪枝的比例
            layer1:需要剪的层对象
            name:指定剪的权重还是偏执
    剪枝固化api:prune.remove(layer1, 'weight')
            参数不用过多介绍,功能是剪枝后的模型固化(永久化)
    '''
    layer1 = model.conv1
    print("--------------------------------------剪枝前----------------------------------")
    # print(list(layer1.named_parameters()))
    # print(list(layer1.named_buffers()))
    prune.random_unstructured(layer1, name="weight", amount=0.3)
    print("--------------------------------------剪枝后-----------------------------------")
    # print(list(layer1.named_parameters()))
    # print(list(layer1.named_buffers()))
    prune.remove(layer1, 'weight')
    print("-------------------------------------模型固化后---------------------------------")
    # print(list(layer1.named_parameters()))
    # print(list(layer1.named_buffers()))


    '''-------------------------------------多参数多网络结构剪枝---------------------------------'''
    for name, module in model.named_modules():
        print(name,module)
        # prune 20% of connections in all 2D-conv layers
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.2)
            prune.remove(module, 'weight')
        # prune 40% of connections in all linear layers
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)
            prune.remove(module, 'weight')

    print(dict(model.named_buffers()).keys())  # to verify that all masks exist
    return 0

(2)全局剪枝:
剪枝所占比例是按照所有参数来算的,不是按照每层的数量来算的,剪枝时候也按整体来算。

def glob_cut(model):
    '''
    全局剪枝:
    '''
    parameters_to_prune = (
        (model.conv1, 'weight'),
        (model.conv2, 'weight'),
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
    )

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.6,
    )
    print(list(model.named_parameters()))
    print(list(model.named_buffers()))

(3)自定义剪枝
该方法不说了,饿了,要吃饭了,急的话参考下官方教程,最后有。文章来源地址https://www.toymoban.com/news/detail-624318.html

到了这里,关于模型优化之模型剪枝的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包