PyTorch框架训练的几种模型区别

这篇具有很好参考价值的文章主要介绍了PyTorch框架训练的几种模型区别。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

PyTorch系列文章目录



前言

在PyTorch中,.pt、.pth和.pth.tar都是用于保存训练好的模型的文件格式,它们之间的主要区别如下:

.pt文件是PyTorch 1.6及以上版本中引入的新的模型文件格式,它可以保存整个PyTorch模型,包括模型结构、模型参数以及优化器状态等信息。.pt文件是一个二进制文件,可以通过torch.save()函数来保存模型,以及通过torch.load()函数来加载模型。

.pth文件是PyTorch旧版本中使用的模型文件格式,它只保存了模型参数,没有保存模型结构和其他相关信息。.pth文件同样是一个二进制文件,可以通过torch.save()函数来保存模型参数,以及通过torch.load()函数来加载模型参数。

.pth.tar文件是一个压缩文件,它包含一个.pth文件以及其他相关信息,比如模型结构、优化器状态、超参数等。.pth.tar文件可以通过Python的标准库tarfile来解压,然后通过torch.load()函数来加载模型。

总的来说,.pt文件是最新的、最全面的模型保存格式,可以保存整个PyTorch模型,包括模型结构、参数、优化器状态等信息。.pth文件只保存了模型参数,而.pth.tar文件则是在.pth基础上加入了一些元数据信息,可以方便地保存和加载整个模型状态。在实际应用中,我们可以根据需要选择适合自己的模型保存格式。


一、.pt模型使用介绍

.pt模型文件是PyTorch框架中保存模型权重的文件格式,其结构包含以下几个部分:
Header:文件开头的一段信息,包含了PyTorch版本、模型结构等元数据信息。
State dictionary:模型的权重数据,以Python的字典形式保存。每个键对应了模型的一个参数名,值则是对应的权重矩阵或向量。
Optimizer state:如果模型使用了优化器,那么这里保存了优化器的状态信息,包括当前的学习率、动量等参数。
Other metadata:保存了一些附加的元数据信息,比如模型训练时使用的超参数、训练数据集的统计信息等。
要解读.pt模型文件的信息,可以使用PyTorch提供的torch.load()函数来加载模型文件,然后可以通过访问字典中的键值对来获取模型的权重和其他信息。例如,可以使用以下代码加载模型文件并查看模型结构和权重:

import torch
model = torch.load('model.pt')
print(model)

该代码会输出模型的结构和权重信息,可以通过访问字典中的键值对来获取具体的权重数值。例如,可以使用以下代码获取模型中名为’conv1.weight’的卷积层权重矩阵:

weights = model['conv1.weight']
print(weights)

这样就可以查看模型文件中保存的权重信息,并进一步用于模型的部署或微调等操作。

二、.pth模型使用介绍

Pytorch目前成为学术界最流行的DL框架,没有之一。很大程度上,简洁直观地操作有关。模型的保存和加载,于pytorch而言,也是很简单的。本文做了一个比较实验,方便大家理解。

首先,要清楚几个函数:torch.save,torch.load,state_dict(),load_state_dict()。
先举最简单的例子:

import torch

model = torch.load('my_model.pth')
torch.save(model, 'new_model.pth')

上面的代码非常直观,一载一存。但是有一个问题,这样保存的pth文件直接包含了整个模型的结构。当你需要灵活加载模型参数时,比如只加载部分参数,那么这种情况保存的pth文件读取进来还得额外解析出“参数文件”。

如果想更灵活对待咱们训练好的模型参数,咱们可以使用下面这个方法。pytorch把所有的模型参数用一个内部定义的dict进行保存,自称为“state_dict”。这个所谓的state_dict就是不带模型结构的模型参数了~
咱们的加载和保存就发生了一点微妙的变化:

import torch
model = MyModel() # init your model class, build the graph shape
state_dict = torch.load('model_state_dict.pth')
model.load_state_dict(state_dict)
torch.save(model.state_dict(), 'model_state_dict1.pth')

比较上面两段代码,咱们可以有一下结论:

pth文件既可能保存了模型的图结构,也有可能没保存;
加载没保存图结构的pth时,需要先初始化模型结构,即把架子搭好;
在保存模型的时候,如果不想保存图结构,可以单独保存model.state_dict()

实验
脚本如下:

import torch
import torchvision.models as models

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'only_weights.pth')

model_state_dict = torch.load('only_weights.pth')
model1 = models.vgg16() # describe the graph shape
model1.load_state_dict(model_state_dict)
model1.eval()

torch.save(model1, 'whole_model.pth')

model2 = torch.load('whole_model.pth')
model2.eval()

# model3 = torch.load('only_weights.pth')
# model3.eval()    # Error

model3切换到eval()模式就会报错,原因是model3只包含weights而缺乏图结构~

三、.pth.tar模型使用介绍

由于为我的特定应用程序重新训练初始模型需要大量计算资源,我想使用已经重新训练的模型。
此模型保存为 .pth.tar文件。
我希望能够首先加载这个模型。到目前为止,我已经能够弄清楚我必须使用以下内容:

model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')

这似乎有效,因为 print(model)打印出大量数字和其他值,我认为这些值是权重和偏差的值。
在此之后,我需要能够用它对图像进行分类。我一直无法弄清楚这一点。我必须如何格式化图像?图像是否应该转换为数组?在此之后,我必须如何将输入数据传递给网络?

如果您有 .pth.tar文件,您可以加载它,从而覆盖已定义模型的参数值。

这意味着保存/加载模型的一般过程如下:
编写您的网络定义(即您的 nn.Module 对象)
以您想要的方式训练或以其他方式更改网络参数
使用 torch.save 保存参数
当您想使用该网络时,请使用 nn.Module 的相同定义对象首先实例化 pytorch 网络
然后使用 torch.load 覆盖网络参数的值

这是一个超短的 mwe:

四、.pkl模型

保存

torch.save({
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict(),
}, 'filename.pth.tar')

加载

checkpoint = torch.load('filename.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

总结

https://blog.csdn.net/Cretheego/article/details/128789192文章来源地址https://www.toymoban.com/news/detail-402501.html

到了这里,关于PyTorch框架训练的几种模型区别的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • PyTorch 还提供的几种连接张量的方法

    1. torch.stack() 方法: 行为: 创建一个新的维度,并在该维度上堆叠张量。结果张量的维度比原始张量多一维。 适用场景: 当需要在新的维度上堆叠张量时,特别是在创建新批次或样本时。 2. torch.cat() 与 torch.unsqueeze() 方法: 行为: 使用 torch.unsqueeze() 在现有维度上增加一个

    2024年01月17日
    浏览(22)
  • Hive的几种排序方式、区别,使用场景

    Hive 支持两种主要的排序方式: ORDER BY 和 SORT BY 。除此之外,还有 DISTRIBUTE BY 和 CLUSTER BY 语句,它们也在排序和数据分布方面发挥作用。 1. ORDER BY ORDER BY 在 Hive 中用于对查询结果进行全局排序,确保结果集是全局有序的。但是,使用 ORDER BY 时,Hive 会将所有数据集中到一个

    2024年02月02日
    浏览(31)
  • CSS中隐藏页面元素的几种方式和区别

    前言、 在平常的样式排版中,我们经常遇到将某个模块隐藏的场景,通过css隐藏的元素方法有很多种,它们看起来实现的效果是一致的,但实际上每一种方法都有一丝轻微的不同,这些不同决定了在一些特定场合下使用哪一种方法。 实现方法综合、 通过css实现隐藏元素方法

    2024年01月20日
    浏览(54)
  • 【SpringBoot系列】接收前端参数的几种方式

    前言 在现代Web开发中,前后端分离的架构已经成为主流。前端负责展示页面和用户交互,而后端则负责处理业务逻辑和数据存储。在这种架构下,前端需要将用户输入的数据发送给后端进行处理。而Spring Boot作为一种快速开发框架,提供了多种方式来接收前端数据。 本文将介

    2024年02月05日
    浏览(34)
  • 【SpringBoot系列】实现跨域的几种方式

    前言 在Web开发中,跨域是一个常见的问题。由于浏览器的同源策略,一个Web应用程序只能访问与其自身同源(即,相同协议、主机和端口)的资源。 这种策略的存在是为了保护用户的安全,防止恶意网站读取或修改用户的数据。 然而,现代Web应用程序经常需要访问不同源的

    2024年02月01日
    浏览(40)
  • VMWare虚拟机中的几种网络配置区别(桥接、仅主机、NAT)

    当在VM虚拟机上安装系统时,会提示进行虚拟机网络配置的选择操作,如下图:   那么,这几种网络连接方式的区别是什么呢? 当电脑安装好VMWare虚拟后,在本机的网络配置中,会多出来两个虚拟网卡, VMnet1 和 Vmnet8, 这两个网卡就是用于虚拟机的网络配置使用。  打开VM虚

    2024年02月09日
    浏览(31)
  • C++技能系列 ( 2 ) - const的几种使用【详解】

    C++高性能优化编程系列 深入理解软件架构设计系列 高级C++并发线程编程 C++技能系列 期待你的关注哦!!! 生活就是上帝发给你的一张手牌,无论多烂,你都得拿着。 Life is god give you a hand, no matter how bad, you have to take. (1)表示常量a,不能改变a的值 (1)表示常量引用,a代

    2024年02月11日
    浏览(27)
  • 【SpringBoot系列】读取yml文件的几种方式

    前言 在Spring Boot开发中,配置文件是非常重要的一部分,而yml文件作为一种常用的配置文件格式,被广泛应用于Spring Boot项目中。Spring Boot提供了多种方式来读取yml文件中的属性值,开发者可以根据具体的需求和场景选择合适的方式。本文将介绍Spring Boot读取yml文件的主要方式

    2024年02月05日
    浏览(30)
  • 【Mysql系列】mysql中删除数据的几种方法

    在MySQL数据库中,删除数据是一个常见的操作,它允许从表中移除不再需要的数据。在执行删除操作时,需要谨慎,以免误删重要数据。 以下是MySQL中删除数据的几种方法: DELETE语句 DROP TABLE语句 TRUNCATE TABLE 使用外键约束 DELETE语句是最常用的删除数据方法之一。它允许您根据

    2024年02月05日
    浏览(31)
  • 游戏中模型动画的几种实现方式

    游戏内动画的实现方式一般有这几种: 骨骼动画 顶点动画 材质动画 CPU蒙皮动画 骨骼动画是一种基于骨骼系统的动画技术,它通过对骨骼进行变换来控制模型的姿态和动作。 在骨骼动画中,模型通常被分解成多个部分,每个部分都与一个或多个骨骼相连,通过对骨骼进行旋

    2024年02月05日
    浏览(42)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包