对 MODNet 其他模块的剪枝探索

这篇具有很好参考价值的文章主要介绍了对 MODNet 其他模块的剪枝探索。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

写在前面

先前笔者分享了《对 MODNet 主干网络 MobileNetV2的剪枝探索》,没想到被选为了CSDN每天值得看系列,因为笔者开设的专栏《MODNet-Compression探索之旅》仅仅只是记录笔者在模型压缩领域的探索历程,对此笔者深感荣幸,非常感谢官方大大的认可!!!接下来,笔者会加倍努力,创作更多优质文章,为社区贡献更多有价值、有意思的内容!!!!

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

本文将分享笔者对 MODNet 网络结构内部其他模块的剪枝探索,剪枝策略同前文主干网络是一样的,剪枝完成后对参数进行替换即可,接下来,就开启探索之旅吧~~

1 开展思路

  1. 访问 MODNet 获取模块;
  2. torch.save(model.state_dict(), path),并检测能否 load,注意参数;
  3. 修改替换脚本中 for 循环下的 if 条件判断;
  4. 修改backbone、MODNet中 IBNorm 以及 wrapper 中的 channels,run script;
  5. 加载替换后的模型参数,观察是否能够成功执行。

2 核心要义

  1. 模型分析:根据先前对剪枝后 MobileNet V2 的结构修改,以及嵌入 MODNet 后的 channel 修改情况,确定待修改的网络层;

  2. 通道裁剪:根据1得到的待修改的网络层进行裁剪,以满足结构与参数匹配的情况;

  3. 参数嵌入:确认 channel 匹配以后,将参入嵌入 MODNet;

3 探索过程

确定修改后的结构与原先的区别在于下列网络层:

  • backbone;
  • lr_branch中的 lr16x、lr8x;
  • hr_branch中 enc2x;

目前,已对 backbone 成功嵌入。

接下来,针对lr16x、lr8x进行剪枝处理,但通过观察可以发现,这两层的前面存在着 se_block 模块,因此,先对 se_block 进行处理。

3.1 se block

观察该部分在 MODNet 中的尺寸与网络层名称:

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

获取并替换成功!不过这部分详细的过程笔者没有记录!存在不周,请谅解~~

3.2 lr16x、lr8x

💥注意:由于起初缺乏对网络层的分析,因此,在进行这两层的嵌入时,仅仅只是单一的嵌入。

将lr16x嵌入以后,出现了“参数 shape > 结构 shape”的情况。

于是,笔者联想到先前的解决方案固定结构,重新进行参数替换。但即便如此,通过键值对获取参数时,参数中的通道数尺寸并未发生变化。(因此,先前的这种方法存在不合理性,但却在执行后可以成功匹配,目前还没有进一步探寻。)

合理的方案以及针对情况如下

  • 对于output channel:单独提取该层,进行剪枝。(但是,如果和它相连的下一层 input channel 也发生了变化,需要将其合并,同时处理,这样,上一次的输出决定着下一层的输入。
  • 对于input channel:如上,合并处理。但是,如果与该层相连的上一层channel保持不变,那就无法使用剪枝。目前的解决方案是,切片提取,先满足结构要求。

而 lr16x 与 lr8x 正适合第一种情况!

原结构:

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

修改后的结构:

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

将 lr16x 与 lr8x 作为一个 sequential,剪枝:

model = modnet.MODNet(backbone_pretrained=False)
pretrained_ckpt = 'modnet_photographic_portrait_matting.ckpt'
model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_ckpt).items()})

# get model
model = nn.Sequential(model.lr_branch.conv_lr16x, model.lr_branch.conv_lr8x)
print(model)

# pruning
# 由于是针对lr16x的output以及lr8x的input,因此这里排除lr8x即可
config_list = [{'sparsity': 0.5,
                'op_types': ['Conv2d']},
               {'exclude': True,
                'op_names': ['1.layers.0']}
               ]

pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner._unwrap_model()
ModelSpeedup(model, dummy_input, masks).speedup_model()
print(model)

结构变化:

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

修改网络结构(mobilenet、wrapper、IBNorm),加载裁剪后的参数,能成功执行计算:

IBNorm结构变化,init部分:

    def __init__(self, in_channels):
        super(IBNorm, self).__init__()
        in_channels = in_channels

        # 针对lr_16x
        if in_channels == 48:
            self.bnorm_channels = 27
            self.inorm_channels = 21
        else:
            self.bnorm_channels = int(in_channels / 2)
            self.inorm_channels = in_channels - self.bnorm_channels 

加载:

model = modnet.MODNet(backbone_pretrained=False)
model = nn.Sequential(model.lr_branch.conv_lr16x, model.lr_branch.conv_lr8x)
model.load_state_dict(torch.load('test.pth'))

dummy_input = torch.randn([1, 1280, 32, 32])
flops, params, _ = count_flops_params(model, dummy_input, verbose=True)
print(f"Model FLOPs {flops / 1e6:.2f}M, Params {params / 1e6:.2f}M")

结果:

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

替换MODNet中,这一部分的参数,保存并加载:

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

3.3 enc2x

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

至此,三个模块的参数全部嵌入!

4 探索结果

4.1 模型大小

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

4.2 参数量与计算量

剪枝前 剪枝后
参数量 6.45 M 3.36 M
计算量 18117.07 M 15315.94 M

4.3 推理时延

序号 剪枝前 剪枝后
1 0.89 0.67
2 0.96 0.68
3 0.86 0.67

4.4 精度

评估指标 原模型 针对MobileNet V2剪枝后 微调后 从头训练后
MSE 0.004299 0.360781 0.140384 0.104005
MAD 0.008141 0.576560 0.211169 0.124459

5 实际推理测试

使用微调后的pth导出onnx模型:

model.eval()
batch_size = 1
height = 512
width = 512
dummy_input = Variable(torch.randn(batch_size, 3, height, width))

torch.onnx.export(
    model, dummy_input, 'test_modnet.onnx', export_params=True,
    input_names=['input'], output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                  'output': {0: 'batch_size', 2: 'height', 3: 'width'}}, opset_version=11)

推理:

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

和微调前的推理结果并无差别,但在直接使用pth格式模型推理时差异较大。

为何会这样?难道是因为笔者选用的不是人像,而是天线宝宝?

在观察导出的 ONNX 格式模型时,笔者发现模型输出节点的个数发生了变化。

原因是笔者在导出时没有注意 output,使用官方脚本解决了~

💥注意:这也就告诉我们,模型导出时的成功提示并不一定是真正处理好了,很多内部细节的丢失会对模型的推理精度带来致命的效果,这时我们可以重新思考模型的输入与输出,或者采用可视化的方式进行查看!

再次推理:

对 MODNet 其他模块的剪枝探索,MODNet-Compression探索之旅,剪枝,算法,深度学习,人工智能,计算机视觉

虽然效果仍然不理想,但至少好了很多,而且可以看出来,笔者选用的测试样例确实不是人!

推理时延变化:240ms---> 192ms,有明显改进!


在导出时也遇到了一个error:

onnxruntime::UpsampleBase::ScalesValidation scale >= 1 was false. Scale value should be greater tha

分析原因:调用 torch.export 时未指定 op_version;

解决方案:考虑到 笔者的pytorch version>=1.3.1,因此直接指定其为op为11,完成了推理!文章来源地址https://www.toymoban.com/news/detail-814105.html

6 结论 

  1. 在替换除了 MobileNet V2 以外的其他部分时,没有考虑整体,仅仅只是对单一的卷积层剪枝,以致于相连的下一个卷积层无法修改通道数。因此,剪枝无法直接对 input channels 操作,只能针对 output channels,进而影响 input channels。
  2. 关于IBNorm,直接修改了channels,可以运行,但缺乏通用性!
  3. 成功嵌入了除 MobileNet V2 以外的参数,并成功导出 ONNX 模型,完成模型推理!
  4. 经测试,模型大小、参数量降低了一半,推理时延降低 20%,从模型压缩的轻量化角度来看,本次探索是成功的,但从模型本身的精度来看,还有很长一段路要走!

到了这里,关于对 MODNet 其他模块的剪枝探索的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • sklearn中决策树模块的剪枝参数ccp_alpha如何可视化调整

    决策树作为树模型中最经典的算法,根据训练数据生长并分裂叶子结点,容易过拟合。所以一般来说会考虑生长停止后进行剪枝,把一些不必要的叶子结点去掉(让其父结点作为叶子结点),这样或许对其泛化能力有积极作用。 在scikit-learn的决策树模块里,默认是不剪枝的,

    2024年02月21日
    浏览(33)
  • 【微服务】集成其他已有的模块

    集成完成

    2024年02月15日
    浏览(28)
  • springboot扫描不到其他模块下定义的Bean

    springboot默认是不能扫描到其他依赖模块定义的Bean的。(默认扫描的是启动类所在包下的所有Bean)也就是在项目启动的不能将其他模块的Bean加载到spring容器 项目之间要有联系性 admin模块为springboot框架,其他的只是普通的maven项目,admin 默认是无法扫描到 framework模块里面的

    2023年04月26日
    浏览(26)
  • orchard core 搭建cms 加载其他模块的管理1

    有一个具体的例子 :https://github.com/OrchardCMS/OrchardCore.Samples 1、先使用教程,安装cms -可以是完全 也可以是采用前后端分离管理。 修改对应的program.cs 的内容: `var builder = WebApplication.CreateBuilder(args); // Add services to the container. //builder.Services.AddRazorPages(); builder.Services.AddOrchardCor

    2024年02月08日
    浏览(23)
  • 【THM】Burp Suite:Other Modules(其他模块)-初级渗透测试

    除了广泛认可的Repeater和Intruder房间之外,Burp Suite 还包含几个鲜为人知的模块。这些将成为这个房间探索的重点。 重点将放在解码器、比较器、排序器和组织器工具上。它们促进了编码文本的操作,支持数据集的比较,允许分析捕获的令牌内的随机性,并帮助您存储和注释

    2024年04月11日
    浏览(32)
  • 启动报错:SpringBootApplication扫描不到其他模块下的bean问题导致postman接口报404

    当启动类在3包下,但是我们的代码写在1包下的2中,代码没有生效,发接口报404,原因是启动类4没有扫描到1包下的代码。 启动报错:SpringBootApplication扫描不到其他模块下的bean问题导致postman接口报404 提示:关联两个包 需要在3包下的pom.xml中引入2这个包

    2024年02月14日
    浏览(33)
  • 带你探索400G光模块测试

    随着移动互联网、云计算、大数据等技术快速发展,数据中心及云计算资源需求的爆发式地增长,核心网传输带宽需求大幅度的提升,同时也带动了超大规模云数据中心的发展,对数据中心内部和之间的互联的光模块带宽需求呈快速增长,促使数据中心从100G向更高速率、更大

    2024年02月08日
    浏览(24)
  • 《花雕学AI》Poe:一个让你和 AI 成为朋友的平台,带你探索 ChatGPT4 和其他 八种AI 模型的奥秘

    你是否曾经梦想过,能够在一个平台上,和多种不同的 AI 模型进行有趣、有用、有深度的对话,甚至还能轻松地把你的对话分享给其他人?如果你有这样的梦想,那么 Poe 一站式 AI 工具箱就是你的不二之选! Poe 是国外知名问答社区 Quora 推出的 AI 平台,它汇集了多个基于大

    2024年02月03日
    浏览(30)
  • C - Songs Compression

    Ivan has nn songs on his phone. The size of the ii-th song is a_iai​ bytes. Ivan also has a flash drive which can hold at most mm bytes in total. Initially, his flash drive is empty. Ivan wants to copy all nn songs to the flash drive. He can compress the songs. If he compresses the ii-th song, the size of the ii-th song reduces from a_iai​ t

    2024年02月09日
    浏览(23)
  • 网络安全B模块(笔记详解)- 隐藏信息探索

    1.访问服务器的FTP服务,下载图片QR,从图片中获取flag,并将flag提交; ​ 通过windows电脑自带的图片编辑工具画图将打乱的二维码分割成四个部分,然后将四个部分通过旋转、移动拼接成正确的二维码 ​ 使用二维码扫描工具CQR.exe扫描该二维码 ​ 获得一串base64的编码,进行

    2024年01月16日
    浏览(46)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包