浅谈一谈pytorch中模型的几种保存方式、以及如何从中止的地方继续开始训练;

这篇具有很好参考价值的文章主要介绍了浅谈一谈pytorch中模型的几种保存方式、以及如何从中止的地方继续开始训练;。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一、本文总共介绍3中pytorch模型的保存方式:1.保存整个模型;2.只保存模型参数;3.保存模型参数、优化器、学习率、epoch和其它的所有命令行相关参数以方便从上次中止训练的地方重新启动训练过程。

1.保存整个模型。这种保存方式最简单,保存内容包括模型结构、模型参数以及其它相关信息。代码如下:

# 保存模型,PATH为模型的保存路径及模型命名
import torch
torch.save(model,PATH)

# 加载模型
model = torch.load(PATH)

2. 只保存模型参数,不保存模型结构和其它相关信息。这种方式保存的模型,在加载模型前需要构建相同的模型结构,然后再将加载的模型参数赋值给对应的层。代码如下:

# 只保存模型参数
torch.save(model.state_dict(), PATH)

# 创建相同结构的模型,然后加载模型参数
model = Model()   # 调用Model类实例化模型
model_dict = torch.load(PATH)
model.load_state_dict(model_dict) #加载模型参数

如果进行模型加载前,创建的模型结构发生了改变,和原来预训练的模型的结构不同,则需要遍历模型参数进行选择性赋值,例如下面的代码:

from collections import OrderedDict

model = Unet()  # 实例化Unet模型
model_dict = torch.load(pretrained_pth, map_location="cpu")  # 加载模型时将参数映射到CPU上
new_state_dict = OrderedDict()  # 新建一个字典类型用来存储新的模型参数
# 改变模型结构名称,如果有,就去掉backbone.前缀
for k, v in model_dict["state_dict"].items():
    new_state_dict[k.replace("backbone.", "")] = v

model.load_state_dict(new_state_dict)  # 加载模型参数

注意上述代码中,有一个参数 map_location="cpu",这个参数是指定将模型参数映射到CPU上,这个参数一般在一下情况下比较适用:1. 当你在CPU上训练了一个模型,并且想将其加载到CPU上进行推断或者继续训练时,使用map_location="cpu"可以确保模型参数被正确地映射到CPU上;2.如果你的预训练模型是在GPU上训练的,但是你在没有GPU的环境中加载模型时,使用这个参数可以避免找不到GPU而导致的错误。 而如果你的代码没有指定map_location参数,则默认情况下pytorch会尝试将模型加载到当前可用设备上(通常是GPU)

3. 保存模型必要参数,使下次训练可以从模型训练停止的地方继续训练,代码如下:

# 将需要保存的参数打包成字典类型
save_file = {"model": model.state_dict(),
             "optimizer": optimizer.state_dict(),
             "lr_scheduler": lr_scheduler.state_dict(),
             "epoch": epoch,
             "args": args}     

# 保存模型和其它参数    
torch.save(save_file, "save_weights/model.pth")
    
# 加载模型和必要的参数
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])  # 加载模型参数
optimizer.load_state_dict(checkpoint['optimizer'])  # 加载模型优化器
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])  # 加载模型学习策略
args.start_epoch = checkpoint['epoch'] + 1  # 加载模型训练epoch停止数

如果仅是进行模型推理,则只用加载模型参数即可,不用加载其它的东西。文章来源地址https://www.toymoban.com/news/detail-796641.html

到了这里,关于浅谈一谈pytorch中模型的几种保存方式、以及如何从中止的地方继续开始训练;的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 谈一谈redis脑裂

    (1)一主多从架构中,主节点与客户端通信正常,主节点与哨兵、从节点连接异常,客户端仍正常写入数据 (2)哨兵判定主节点下线,重新选主 (3)原主节点与哨兵和其他节点通信恢复,成为新主节点的从节点,drop本身所有的数据,从新主节点全量copy数据 (4)原主节点

    2024年02月12日
    浏览(47)
  • 谈一谈接口测试

    我相信你一定听说过这样一句话:“测试要尽早介入,测试进行得越早,软件开发的成本就越低,就越能更好地保证软件质量。” 但是如何尽早地进入测试,作为软件测试的你,是不是也没办法说得清楚呢?其实上面那句话中的“测试”,所指的并不是测试这个人,而是指包

    2024年02月08日
    浏览(52)
  • 谈一谈扫码登录原理

      今天给大家介绍下扫码登录功能是怎么设计的。 扫码登录功能主要分为三个阶段: 待扫描、已扫描待确认、已确认 。 整体流程图如图。 下面分阶段来看看设计原理。 1、待扫描阶段 首先是待扫描阶段,这个阶段是 PC 端跟服务端的交互过程。 每次用户打开PC端登陆请求,

    2024年02月10日
    浏览(47)
  • 谈一谈缓存穿透,击穿,雪崩

    缓存穿透是指在使用缓存系统时,频繁查询一个不存在于缓存中的数据,导致这个查询每次都要通过缓存层去查询数据源,无法从缓存中获得结果。这种情况下,大量的请求会直接穿透缓存层,直接访问数据源,从而增加了系统的负载,降低了系统的性能。 通常情况下,当一

    2024年02月14日
    浏览(52)
  • 谈一谈Python中的装饰器

    1.1 何为Python中的装饰器? Python中装饰器的定义以及用途: 装饰器是一种特殊的函数,它可以接受一个函数作为参数,并返回一个新的函数。装饰器可以用来修改或增强函数的行为,而不需要修改函数本身的代码。在Python中,装饰器通常用于实现AOP(面向切面编程),例如日

    2023年04月16日
    浏览(63)
  • 谈一谈冷门的C语言爬虫

    C语言可以用来编写爬虫程序,但是相对于其他编程语言,C语言的爬虫开发可能会更加复杂和繁琐。因为C语言本身并没有提供现成的爬虫框架和库,需要自己编写网络请求、HTML解析等功能。 不过,如果你对C语言比较熟悉,也可以尝试使用C语言编写爬虫程序,这样可以更好地

    2024年02月08日
    浏览(57)
  • 【大数据面试题】007 谈一谈 Flink 背压

    一步一个脚印,一天一道面试题 (有些难点的面试题不一定每天都能发,但每天都会写) 在流式处理框架中,如果下游的处理速度,比上游的输入数据小,就会导致程序处理慢,不稳定,甚至出现崩溃等问题。 上游数据突然增大 比如数据源突然数据量增大多倍,下游处理速

    2024年02月20日
    浏览(56)
  • 谈一谈Vue怎么用extend动态创建组件

    Vue.js是一个流行的JavaScript框架,它提供了许多功能来帮助我们构建交互式Web应用程序。其中之一是使用extend方法动态创建组件。   extend方法是Vue.js提供的一个方法,它允许我们创建一个新的Vue组件构造函数。这个新的构造函数可以继承现有的组件,也可以添加新的选项。 我

    2023年04月24日
    浏览(44)
  • [轻科普]谈一谈最近手机上的2亿像素

    最近很多厂商发布了2亿像素的手机,2亿像素比较火热,如realme 11 pro + ,荣耀的honor 90 pro,以及之前小米发布的Redmi note 12 pro +。 下图为honor 90 Pro上搭载的2亿像素 ,为S5KHP3 下图为 红米上搭载的S5kHPX 2亿像素传感器。    下图为 Realme的两亿像素,S5KHP3的超级变焦版本   以上三

    2024年02月06日
    浏览(57)
  • 谈一谈如何加快android的项目的编译速度

    随着android的组件化的到来,一个项目后期功能越来越多,模块拆分的越来越多,作为android的开发的小伙伴就不得不面对运行一下android项目可能需要5,6分钟甚至10几分钟的等待期,开发时间都浪费在编译上了,你说烦不烦呢!那么怎么解决这个困境,总不能就这么一直凑合着

    2024年02月12日
    浏览(66)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包