MindSpore-FCOS模型权重迁移推理对齐实录

这篇具有很好参考价值的文章主要介绍了MindSpore-FCOS模型权重迁移推理对齐实录。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

准备工作

环境:
wsl2 Ubuntu 20.04
mindspore 2.0.0
python 3.8
pytorch 2.0.1 cpu

基于已有的mindspore FCOS项目和FCOS官方pytorch权重来做迁移,

  • FCOS官方pytorch实现
    FCOS_imprv_R_50_FPN_1x权重
  • MindSpore FCOS项目链接
    该代码是mindspore1.6实现的,用新版本运行会有很多warning,warning的接口要更改为新的。
    而且没提供训练好的权重,所以用官方的pytorch权重进行迁移,但其中发现MindSpore相比官方有许多地方不同。

权重转换

迁移其实就是在做权重的键值映射对齐,这其中有一些规律可寻,但不多,更多需要自己的分析比对,建立映射字典。

可参考的经验:

  • https://gitee.com/lirongxi4/pt2ms_convert
    一个迁移脚本,通用性一般
  • https://mindspore.cn/docs/zh-CN/r2.0/migration_guide/overview.html
    MindSpore官方的迁移指南

根据上述迁移经验,打印两种框架的权重的名称及shape进行比对,总结名称转换方式如下(pytorch的名称改为mindspore的):

import copy, torch
import mindspore as ms

def fcos_pth2ckpt():
    m = ms.load_checkpoint('test.ckpt')  # mindspore FCOS保存的随机权重
    t = torch.load('./weights/FCOS_imprv_R_50_FPN_1x.pth', map_location=torch.device('cpu'))  # pytorch FCOS权重
    match_pt_kv = {}  # 匹配到的pt权重的name及value的字典
    match_pt_kv_mslist = []  # 匹配到的pt权重的name及value的字典, mindspore加载权重需求的格式
    not_match_pt_kv = {}  # 未匹配到的pt权重的name及value
    matched_ms_k = []  # 被匹配到的ms权重名称
    
    '''一般性的转换规则'''
    pt2ms = {'module': 'fcos_body',  # backbone部分
             'stem.': '',
             '.body': '',
             '.rpn': '',
             'downsample': 'down_sample_layer',

             'backbone.fpn': 'fpn',  # FPN部分
             'fpn_inner4': 'prj_5',
             'fpn_layer4': 'conv_5',

             'fpn_inner3': 'prj_4',
             'fpn_layer3': 'conv_4',

             'fpn_inner2': 'prj_3',
             'fpn_layer2': 'conv_3',

             'top_blocks.p': 'conv_out',

             'bbox_tower': 'reg_conv',  # head部分
             'cls_tower': 'cls_conv',
             'bbox_pred': 'reg_pred',

             'scales': 'scale_exp',
             'centerness': 'cnt_logits',

             "running_mean": "moving_mean",  # BN部分
             "running_var": "moving_variance",

             }

    '''BN层的特殊转换规则'''
    pt2ms_bn = {
        "weight": "gamma",
        "bias": "beta",
    }


    for i in t['model'].keys():
        pt_name = copy.deepcopy(i)
        pt_value = copy.deepcopy(t['model'][i])
        
        '''通用的处理'''
        for k, v in pt2ms.items():
            if k in pt_name:
                pt_name = pt_name.replace(k, v)
        '''BN层处理'''
        if 'bn' in pt_name:
            for k, v in pt2ms_bn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)
        '''下采样层特别处理'''
        if 'down' in pt_name:
            if 'bias' in pt_name:
                pt_name = pt_name.replace('bias', 'beta')
            if 'down_sample_layer.1.weight' in pt_name:
                pt_name = pt_name.replace('weight', 'gamma')

        '''head部分的特殊处理'''
        if 'cls_conv' in pt_name or 'reg_conv' in pt_name:
            if '1' in pt_name or '4' in pt_name or '7' in pt_name or '10' in pt_name:
                pt_name = pt_name.replace('weight', 'gamma')
                pt_name = pt_name.replace('bias', 'beta')

        '''改名成功,匹配到ms中的权重了,记录'''
        if pt_name in m.keys():
            assert pt_value.shape == m[pt_name].shape
            match_pt_kv[pt_name] = pt_value
            match_pt_kv_mslist.append({'name': pt_name, 'data': ms.Tensor(pt_value.numpy(), m[pt_name].dtype)})
            matched_ms_k.append(pt_name)
        else:
            not_match_pt_kv[i + '   ' + pt_name] = pt_value

    '''打印未匹配的pt权重名称'''
    print('\n\n------------------未匹配的pt权重名称--------------------')
    for j in not_match_pt_kv.keys():
        print(j, np.array(not_match_pt_kv[j].shape))

    '''打印未被匹配到的ms权重名称'''
    print('\n\n------------------未被匹配到的ms权重名称--------------------')
    for j in m.keys():
        if j not in matched_ms_k:
            print(j, np.array(m[j].shape))
    print('end')
    return match_pt_kv_mslist

输出:

------------------未匹配的pt权重名称--------------------

------------------未被匹配到的ms权重名称--------------------
fcos_body.backbone.end_point.weight [1001 2048]
fcos_body.backbone.end_point.bias [1001]

这俩权重不参与模型forward,是冗余的。
match_pt_kv_mslist就是转换后的mindspore权重,加载后测试发现输出有很大出入,第一个原因是mindspore1.10的ops.sort算子有bug,已提交[issue]https://gitee.com/mindspore/mindspore/issues/I7EHKI),后续版本修复了,所以我升级到2.0.0版本了,其他原因就是网络实现未对齐,接下来主要讲这部分。

区别一:输入处理未对齐

MindSpore FCOS项目链接 输入处理方式就与FCOS官方pytorch实现不一样

  • offical pytorch FCOS:BGR 255 ,使用(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])进行归一化
  • MindSpore FCOS:RGB, 使用(mean=[0.40789654, 0.44719302, 0.47026115], std=[0.28863828, 0.27408164, 0.27809835])进行归一化

其他的裁剪,图像padding对推理结果影响不会很大。

归一化对齐为官方实现后仍发现图片值仍有不同(B通道的最大值不一样),可能Normalize的底层实现有区别?没有深究,后续直接用torch的Normalize结果张量输入到mindspore中以实现模型输入对齐。

输入对齐后的测试:使用coco2017验证集第一张图像(val/000000000139.jpg),resize到(800,1216)大小,两个框架的模型分别输入进去,输出有差别,

进行排查,发现模型第一个卷积的padding没对齐。

区别二:第一个7x7卷积padding方式未对齐

pytorch:

torch默认pad模式
MindSpore-FCOS模型权重迁移推理对齐实录
卷积结果:
MindSpore-FCOS模型权重迁移推理对齐实录

mindspore:

same模式下的卷积跟torch的pad模式下肯定不一样,且两种框架的same也不一样:算子区别
MindSpore-FCOS模型权重迁移推理对齐实录
结果自然不一样:
MindSpore-FCOS模型权重迁移推理对齐实录
原实现:

nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, pad_mode='same', weight_init=weight)

改为:

nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', weight_init=weight)

MindSpore-FCOS模型权重迁移推理对齐实录
结果这就对了。

其实发现设置mindsporefocs实现的resnet中的self.res_base=True就会调用正确的7x7卷积。

第一个卷积对了,但后面BN层就不对了,官方的BN层是一种frozenBN,没有使用eps,去除了eps按公式手动计算,但还是有误差,不知为何…

此外,mindspore实现的fcos的卷积pad_mode全选的same,这个肯定与官方的对不齐,pytorch官方的全使用的zeros模式,对应的mindspore应该是pad模式吧

FCOS对齐先放在这儿,后续再处理,已经有了一定的经验,先去做TOOD的迁移。文章来源地址https://www.toymoban.com/news/detail-499770.html

到了这里,关于MindSpore-FCOS模型权重迁移推理对齐实录的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 论文笔记:CVPR2023 IRRA—隐式推理细粒度对齐模型,语言行人检索任务新SOTA,CUHK-PEDES数据集Rank-1可达73.38%!

    论文 :Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval 代码 :https://github.com/anosorae/IRRA 这是今年CVPR2023的工作,也是目前在语言行人检索领域实现SOTA性能的模型,模型整体并不复杂性能却很好,代码也做了开源,是一个非常好的工作。 下面将对该文章进行

    2024年02月13日
    浏览(59)
  • LLMs之llama_7b_qlora:源代码解读inference_qlora.py(模型推理)使用LORA权重来初始化预训练的LLAMA模型来进行文本生成(基于用户交互输入的上下文生成新文本)

    LLMs之llama_7b_qlora:源码解读inference_qlora.py(模型推理)使用LORA权重来初始化预训练的LLAMA模型来进行文本生成(基于用户交互输入的上下文生成新文本) 目录

    2024年02月15日
    浏览(74)
  • 手把手教你用MindSpore训练一个AI模型!

    首先我们要先了解深度学习的概念和AI计算框架的角色( https://zhuanlan.zhihu.com/p/463019160 ),本篇文章将演示怎么利用MindSpore来训练一个AI模型。和上一章的场景一致,我们要训练的模型是用来对手写数字图片进行分类的LeNet5模型 请参考( http://yann.lecun.com/exdb/lenet/ )。 图1 M

    2024年02月04日
    浏览(60)
  • 任意模型都能蒸馏,异构模型的知识蒸馏方法OFAKD已在昇思MindSpore开源

    自知识蒸馏方法在2014年被首次提出以来,其开始广泛被应用于模型压缩领域。在更强大教师模型辅助监督信息的帮助下,学生模型往往能够实现比直接训练更高的精度。然而,现有的知识蒸馏相关研究只考虑了同架构模型的蒸馏方法,而忽略了教师模型与学生模型异构的情形

    2024年02月22日
    浏览(40)
  • 使用MindSpore20.0的API快速实现深度学习模型之数据变换

    大家好,我是沐风晓月,本文是对昇思MindSpore社区的产品进行测试,测试的步骤,记录产品的使用体验和学习。 如果文章有什么需要改进的地方还请大佬不吝赐教👏👏。 🏠个人主页:我是沐风晓月 🧑个人简介:大家好,我是沐风晓月,双一流院校计算机专业😉😉 💕 座

    2024年01月25日
    浏览(37)
  • 【MindSpore易点通机器人-06】基于相似度模型实现问答匹配及推荐功能

    作者:王磊 更多精彩分享,欢迎访问和关注:https://www.zhihu.com/people/wldandan 在上一篇【MindSpore易点通机器人-05】问答数据预处理及编码,我们为大家讲述了机器人问答数据预处理及编码,本篇为大家介绍 机器人基于什么模型实现问答匹配及推荐功能 。 答案搜索的核心逻辑是

    2024年02月16日
    浏览(59)
  • 大语言模型对齐技术 最新论文及源码合集(外部对齐、内部对齐、可解释性)

    大语言模型对齐 (Large Language Model Alignment)是利用大规模预训练语言模型来理解它们内部的语义表示和计算过程的研究领域。主要目的是避免大语言模型可见的或可预见的风险,比如固有存在的幻觉问题、生成不符合人类期望的文本、容易被用来执行恶意行为等。 从必要性上来

    2024年02月05日
    浏览(41)
  • pytorch保存、加载和解析模型权重

    1、模型保存和加载          主要有两种情况:一是仅保存参数,二是保存参数及模型结构。 保存参数:          torch.save(net.state_dict()) 加载参数(加载参数前需要先实例化模型):          param = torch.load(\\\'param.pth\\\')          net.load_state_dict(param) 保存模型结构

    2024年02月16日
    浏览(44)
  • AI模型部署基础知识(一):模型权重与参数精度

    一般情况来说,我们通过收集数据,训练深度学习模型,通过反向传播求导更新模型的参数,得到一个契合数据和任务的模型。这一阶段,通常使用pythonpytorch进行模型的训练得到pth等类型文件。AI模型部署就是将在python环境中训练的模型参数放到需要部署的硬件环境中去跑,

    2024年01月20日
    浏览(50)
  • 【tips】huggingface下载模型权重的方法

    方法1:直接在Huggingface上下载,但是要fanqiang,可以git clone或者在代码中: 方法2:使用modelscope: 方法3:使用hf的镜像网站,https://hf-mirror.com/baichuan-inc 代码还是使用的huggingface那坨,但是在terminal运行代码时加上 HF_ENDPOINT=https://hf-mirror.com : 注:huggingface的镜像网站下载llam

    2024年02月08日
    浏览(77)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包