pytorch对已有模型的更改(常用的操作)

这篇具有很好参考价值的文章主要介绍了pytorch对已有模型的更改(常用的操作)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

本文会做经常性的更改,如有错误或者其他补充的,请各位大佬不吝指点。

如图所示为我们的示例输出的网络结构。

pytorch对已有模型的更改(常用的操作)

引入创建的模型:

import torch
import simple_module
mod = simple_module.Module()

一、模型的保存与读取

1.整个模型的保存与读取

# 保存整个模型
torch.save(mod, '../parameters/mod.pth')
# 读取整个模型
mod_load = torch.load('../parameters/mod.pth')

2.模型参数的保存与读取(以字典方式保存和读取)

# # 保存模型的参数(以字典的方式保存)
torch.save(mod.state_dict(), '../parameters/mod_parameter.pth')
# 查看保存了哪些参数
print(mod.state_dict().keys())
print(mod.state_dict()['feature.0.0.bias'])

# 读取模型的参数(以字典的方式读取)
mod.load_state_dict(torch.load('../parameters/mod_parameter.pth'))
odict_keys(['feature.0.0.weight', 'feature.0.0.bias', 'feature.0.1.weight', 
'feature.0.1.bias', 'feature.0.1.running_mean', 'feature.0.1.running_var', 
'feature.0.1.num_batches_tracked', 'feature.1.0.weight', 'feature.1.0.bias', 
'feature.1.1.weight', 'feature.1.1.bias', 'feature.1.1.running_mean', 
'feature.1.1.running_var', 'feature.1.1.num_batches_tracked', 'classifier.1.weight',
'classifier.1.bias'])
tensor([-0.1721, -0.1222,  0.1023, -0.1484, -0.0547, -0.1922, -0.0796, -0.1784,

        -0.0233, -0.0271, -0.1018,  0.1875])

二、模型更改某一层

# 模型修改某一层
mod.classifier[1] = torch.nn.Linear(in_features=3072, out_features=20, bias=True)

三、模型删除某些层

# 删除某一层,可以将该层设置为空序列
mod.classifier[1] = torch.nn.Sequential()

# 可以采用切片的方式删除,这样删除更加彻底
mod.classifier = torch.nn.Sequential(*list(mod.classifier.children())[:-1])

# 或者直接删除
mod.classifier.__delattr__('1')

四、模型添加层(貌似只能在某一个块的末尾添加,后续再查找资料,有大佬可以指点一下)

# 模型添加层
mod.classifier.add_module(name='liner', module=torch.nn.Linear(in_features=3072, out_features=100, bias=True))

五、冻结某些层,使得训练时不进行参数更行

1.冻结某一层

# 冻结某一层
mod.feature[0][0].weight.requires_grad = False

2.冻结所有的参数

# 冻结所有的参数
for param in mod.parameters():
    param.requires_grad = False

3.冻结前面某部分的参数,可先将参数名称罗列出来,然后选择一部分的参数名称,利用参数的名称进行冻结。这种方式可以任意地冻结自己想要冻结的层。

no_grad = []
for name, value in mod.named_parameters():
    # print(name)
    no_grad.append(name)
no_grad = no_grad[:-4]
for name, value in mod.named_parameters():
    if name in no_grad:
        value.requires_grad = False
    else:
        value.requires_grad = True

 4.还有一种方式,就是只冻结前面几层

i = 0
for name, value in mod.named_parameters():
    value.requires_grad = False
    i = i + 1
    if i == 4:
        break;

或者

model_parameters = model.named_parameters()
for i in range(freeze):
    name, value = next(model_parameters)
    value.requires_grad = False

这是我目前想到的一个方法,还有其他方法的请大佬不吝指点。 

无论哪种方式,都是将对应层的weight的requires_grad设置为False。

5.最后还需要给优化器设置过滤器文章来源地址https://www.toymoban.com/news/detail-416503.html

# 定义一个fliter,只传入requires_grad=True的模型参数
optimizer = optim.SGD(filter(lambda p : p.requires_grad, mod.parameters()), lr=1e-2) 

到了这里,关于pytorch对已有模型的更改(常用的操作)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • pytorch加载模型和模型推理常见操作

    .pth文件可以保存模型的拓扑结构和参数,也可以只保存模型的参数,取决于model.save()中的参数。 使用方式1得到的.pth重构模型代码如下: 使用方式2得到的.pth重构模型代码如下: 以只保存模型参数的pth为例 loaded_dict_enc 的类型是:class ‘odict_items’(有序字典),本质还是

    2024年02月13日
    浏览(39)
  • Verilog中的force语句用来强制更改信号的值,特别适用于仿真和调试。本文将深入探讨force语句在FPGA开发中的应用和注意事项。

    Verilog中的force语句用来强制更改信号的值,特别适用于仿真和调试。本文将深入探讨force语句在FPGA开发中的应用和注意事项。 首先,我们需要了解force语句的语法。其基本格式为force [time] signal = value。其中,time是可选参数,表示在何时开始强制更改信号的值;signal是要更改的

    2024年02月12日
    浏览(71)
  • odoo继承已有视图操作

    Odoo中,tree视图和form视图是两种主要的视图类型,它们分别用于展示记录列表和详细记录表单。 在Odoo中,视图继承允许开发者在不修改原始视图的基础上增加或改变视图的结构或外观。这是通过创建包含继承指令的XML文件来实现的。 Tree视图继承 Tree视图显示记录的列表。要

    2024年03月09日
    浏览(29)
  • Python与Pytorch系列(二) 本文(1.8万字) | 解析Opencv, Matplotlib, PIL | 三者之间的转换 | 三者对JPG和PNG读取和写入 |

    点击进入专栏: 《人工智能专栏》 Python与Python | 机器学习 | 深度学习 | 目标检测 | YOLOv5及其改进 | YOLOv8及其改进 | 关键知识点 | 各种工具教程 推荐网站 : OpenCV Matplotlib Pillow opencv的基本图像类型可以和numpy数组相互转化,因此可以直接调用 torch.from_numpy(img) 将图像转换成 t

    2024年02月03日
    浏览(71)
  • 使用训练好的YOLOV5模型在已有XML标注文件中添加新类别

            训练完一个YOLOV5模型后,可以使用模型快速在已有XML标注文件中添加新类别,下面是在已有XML标注文件中添加新类别的具体脚本:  以上代码需要修改run函数中的:weights为yolov5模型路径,source为图片数据和xml标注文件所在文件夹,修改的xml也在source路径下。亲测

    2024年02月15日
    浏览(37)
  • Linux系列文章 —— vim的基本操作(误入vim退出请先按「ESC」再按:q不保存退出,相关操作请阅读本文)

    vim-操作篇 进程概念篇 进程地址空间篇 Linux,是一种免费使用和自由传播的类UNIX操作系统,是一个基于POSIX的多用户、多任务、支持多线程和多CPU的操作系统。它能运行主要的Unix工具软件、应用程序和网络协议。Linux继承了Unix以网络为核心的设计思想,是一个性能稳定的多用

    2024年02月03日
    浏览(42)
  • 【colab】谷歌colab免费服务器训练自己的模型,本文以yolov5为例介绍流程

    目录 一.前言 二.准备工作 1.注册Google drive(谷歌云盘) Google Driver官网:https://drive.google.com/drive/ Colab官网:https://colab.research.google.com/ 2.上传项目文件 3.安装Colaboratory 4.colab相关操作和命令 5.项目相关操作  三.异常处理         本文介绍了在谷歌开放平台Google colab上租用免

    2023年04月08日
    浏览(51)
  • 快慢指针该如何操作?本文带你认识快慢指针常见的三种用法及在链表中的实战

    很多同学都听过 快慢指针 这个名词,认为它不就是定义两个引用(指针)一前一后吗?是的,它的奥秘很深,它的作用究竟有哪些?究竟可以用来做哪些题目?下面我将一一带你了解和应用 下面的本节的大概内容,有疑惑的点,欢迎小伙伴们留言 目录 1.简述快慢指针 2.快慢

    2024年02月04日
    浏览(34)
  • 开箱即用的ChatGPT替代模型,还可训练自己数据

    OpenAI 是第一个在该领域取得重大进展的公司,并且使围绕其服务构建抽象变得更加容易。 然而,便利性带来了集中化、通过中介的成本、数据隐私和版权问题。 而数据主权和治理是这些新的LLM服务提供商如何处理商业秘密或敏感信息的首要问题,用户数据已被用于预训练以

    2023年04月23日
    浏览(54)
  • 计算机网络模型、网络传输、封装分用的详细讲解

    在互联网诞生之前,人们通过发电报等方式进行通信,这种方式是非常不稳定的,通信链路容易被打断。由于战争时期需要更好的通信手段,此时就将原本两点之间简单的通信链路,扩展成复杂的链路,保证就算有一条链路被打断也能进行通信,就类似于我们现在的互联网。

    2023年04月13日
    浏览(39)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包