第五章 模型篇: 模型保存与加载

这篇具有很好参考价值的文章主要介绍了第五章 模型篇: 模型保存与加载。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

参考教程
https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html


训练好的模型,可以保存下来,用于后续的预测或者训练过程的重启。
为了便于理解模型保存和加载的过程,我们定义一个简单的小模型作为例子,进行后续的讲解。

这个模型里面包含一个名为self.p1的Parameter和一个名为conv1的卷积层。我们没有给模型定义forward()函数,是因为暂时不需要用到该方法。假如你想使用这个模型对数据进行前向传播,会返回 “NotImplementedError: Module [Model] is missing the required “forward” function”

import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.t1 = torch.randn((3,2))
        self.p1 = nn.Parameter(self.t1)
        self.conv1 = nn.Conv2d(1, 1, 5)
net = Model()

pytorch中的保存与加载

首先我们来看一下pytorch中的保存和加载的方法是怎么实现的。

torch.save()

参考文档:https://pytorch.org/docs/stable/generated/torch.save.html
首先来看一下torch.save()函数。

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

torch.save()函数传入的第一个参数,就是我们要保存的对象,它的类别要求是object,而没有限定在nn.Module()或者nn.Parameters()等等之间。说明它可以保存的类型是多种多样的,很灵活。
传入的第二个参数是f,f是一个file-like object或者文件路径,也就是我们想要保存的位置。
后面的几个参数可以不用管它,一般也不会用到。从参数名称可以看到,我们想要保存的object是以pickle的形式保存的。因为pickle支持多种数据类型。
在源码中给了两个使用torch.save的例子。

  >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)

第一个例子把一个tensor保存在了‘tensor.pt’中,第二个则是将tensor保存在一个buffer中。这都是允许的。

torch.load()

参考文档:https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
再来看一下torch.load()函数。

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, **pickle_load_args)

torch.load()传入的第一个参数f对应着torch.save()中的f,它可以是一个路径,也可以是一个file-like object。
因为我们的模型训练支持cpu也支持gpu等设备,所以我们保存的object也可能处于多种设备环境中,在torch.load()时,这个object会现在CPU上进行反序列化,然后移动到其保存时所处的设备上。假如当前的系统不支持这个设备,就会出现问题,这个时候就需要使用map_location参数,这个参数可以指定你想要放置object的设备,假如没有特别指定,在设备不能实现时就会报错。
weights_only参数可以限定你先要unpickle的object的种类,在使用weights_only参数的同时,你必须明确定义pickle_moduel这个参数(默认为pickle,这也是对的),否则就会报错RuntimeError(“Can not safely load weights when explicit pickle_module is specified”。一般情况下我们也不需要管这个参数。

代码示例

给出一个简单的例子,我们将一个tensor保存在’tensor.pt’中,又使用torch.load()加载进来。
第五章 模型篇: 模型保存与加载
因为保存支持的输入是object,所以我们即使只保存一个字符串也是可以的。(可以,但没必要)
第五章 模型篇: 模型保存与加载

模型的保存与加载

保存 state_dict()

在之前的章节中有说过,调用model.state_dict()方法时,得到的返回结果是一个orderdict,这个字典的key是模型中参数的名字,value是模型的参数值。
我们通常说的保存模型,保存的就是模型的state_dict(),也就是只保存了模型的参数名和参数值,因此我们是不知道模型的正确结构和forward()中的运算顺序的,你也没有办法直接使用这个state_dict()进行预测。
现在我们保存最开始定义的笨蛋小模型的state_dict()
第五章 模型篇: 模型保存与加载
我们只保存了模型的参数名和参数值,这个’test.pth’的大小只有1.39 KB (1,428 字节)。

nn.Module().load_state_dict()

def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True):

load_state_dict()传入的参数是一个key和value的mapping。这里的keys对应的当前模型自己的state_dict的key,或者说参数名。
在使用load_state_dict()时,该方法会对传入的mapping中的key和模型本身的key进行对比。如果key可以匹配上,就会进行一些操作后,更改模型的key对应的参数值。假如没有匹配上,这个key就会被放进missing_keys或者unexpected_keys中去。
strict这个参数默认是True,所以当有不匹配的key时,就会返回报错。

加载模型参数

我们只保存的模型的参数,所以想要使用这个参数,就需要把它放置在一个现有的模型中去。比如说我们现在有一个新模型model2,它和model1有着一样的结构,但是因为初始化的随机性,它们的参数值可能是不一样的。
第五章 模型篇: 模型保存与加载
可以看到我们的model2中的参数名和model1一样,但是对应的值不一样。
我们可以使用load_state_dict()方法将model1的参数值根据参数名放到model2中去。
第五章 模型篇: 模型保存与加载
现在model1和model2中的参数值也都变得一样了。
假如我们手动修改一下我们使用torch.load()加载的state_dict,给它增加一个新的值。加载时就会报错,出现了unexpected_keys。相应地,假如给它删除一个值,就会出现Missing key(s) 的错误,在这里不举例子。

第五章 模型篇: 模型保存与加载

保存模型本身

torch.save()支持保存的对象是object,而我们的模型本身,作为nn.Module(),自然也是符合object的要求的。因此你也可以直接保存整个模型。
第五章 模型篇: 模型保存与加载
我们保存的是整个模型,包括了模型的结构和模型的参数名+参数值。这个’test2.pth’的大小是2.39 KB (2,457 字节)。

加载模型本身

我们在上面将整个模型都保存在了’test2.pth’中,因此我们使用torch.load('test2.pth)时,获得的结果就是模型本身,它的类型是nn.Module()。
第五章 模型篇: 模型保存与加载

checkpoint

保存与读取

假如我们现在有一个保存好的模型’model.pth’,我们想要继续当前模型的状态继续训练。这个时候我们就会发现,'model.pth’中拥有我们模型的参数名和参数值,但是随着我们之前的训练的进行,我们使用的optimizer或者lr_scheluder的状态我们是无法获取的,它们中也有一些参数可能在训练时发生了变化。
因此为了帮助我们重启训练状态,我们需要保存更多的信息,而不是只保存一个模型的state_dict。这些被保存的信息,统称为checkpoint。
在保存checkpoint时,我们同样使用torch.save()方法,在加载时,也是用torch.load()方法。因为torch.save支持保存各种格式,我们可以将想要保存的信息按照key和value组成一个dict,并将这个dict保存下来。
在下面这个例子中,被保存下来的信息包括当前的epoch数,模型的state_dict, 优化器的state_dict还有louss。

# Additional information
torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

在加载时,我们只要按照key取其中的value就可以。

# Additional information
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

多个模型的保存与读取

我们已经知道可以将key和value对应的dict保存成checkpoint的形式,帮助我们重启训练状态。当我们有多个模型时,只不过是增加了要保存到信息而已,方法是一样的。

# Specify a path to save to
PATH = "model.pt"

torch.save({
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

在这个checkpoint中,我们分别保存了modelA和modelB的state_dict,和它们对应的优化器optimizerA和optimizerB的state_dict。
因此在使用时,只要分别放置到对应的object中就可以。文章来源地址https://www.toymoban.com/news/detail-495398.html

modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

到了这里,关于第五章 模型篇: 模型保存与加载的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 利用pytorch自定义CNN网络(五):保存、加载自定义模型【转载】

    本文转载自: PyTorch | 保存和加载模型 本文主要介绍如何加载和保存 PyTorch 的模型。这里主要有三个核心函数: torch.save :把序列化的对象保存到硬盘。它利用了 Python 的 pickle 来实现序列化。模型、张量以及字典都可以用该函数进行保存; torch.load:采用 pickle 将反序列化的

    2024年02月13日
    浏览(41)
  • 《python语言程序设计基础》(第二版)第五章课后习题参考答案

    第五章 函数和代码的复用 5.1 改造练习题3.5,输出更大的田字格 5.2 实现isOdd函数 5.3 实现isNum函数 5.4 实现multi函数 5.5 实现isPrime函数 5.6 输出10种生日日期格式 代码一: 代码二: 5.7 汉诺塔 注:上述代码仅供参考,若有问题可在评论区留言!

    2024年02月01日
    浏览(51)
  • 如何用pytorch做文本摘要生成任务(加载数据集、T5 模型参数、微调、保存和测试模型,以及ROUGE分数计算)

    摘要 :如何使用 Pytorch(或Pytorchlightning) 和 huggingface Transformers 做文本摘要生成任务,包括数据集的加载、模型的加载、模型的微调、模型的验证、模型的保存、ROUGE指标分数的计算、loss的可视化。 ✅ NLP 研 0 选手的学习笔记 ● python 需要 3.8+ ● 文件相对地址 : mian.py 和 tra

    2024年02月05日
    浏览(73)
  • 第五章 Django 数据模型系统(基本使用)

    第一章 Django 基本使用 第二章 Django URL路由系统 第三章 Django 视图系统 第四章 Django 模板系统 第五章 Django 数据模型系统(基本使用) 第六章 Django 数据模型系统(多表操作) 第七章 Django 用户认证与会话技术 第八章 Django CSRF防护 静态网站和动态网站是两种不同类型的网站,它们

    2024年02月04日
    浏览(43)
  • 第五章 数据分析模型 题目学习(40%)

    主成分的计算步骤:1、主成分建模,标准化处理。2、计算特征根、特征向量。3、选取主成分个数。  选择B,依次递减。  相关系数和关联矩阵都做了标准化,做完标准化后方差就不会造成影响,所以选A。  A可以进行判断,虽然没讲过但是可以。BC是正常概念。D没说过。

    2024年02月09日
    浏览(41)
  • 姜启源数学模型第五版第五章火箭发射升空

    首先先简单的介绍数学建模是一个怎么样的内容 数学建模是一种将数学方法和技术应用于实际问题解决的过程。它通过建立数学模型来 描述、分析和预测现实世界中的各种问题 . 数学建模的内容可以包括以下几个方面: 1. 问题定义(问题重述) :明确问题的目标、约束条件和

    2024年02月11日
    浏览(37)
  • 【RabbitMQ教程】第五章 —— RabbitMQ - 死信队列

                                                                       💧 【 R a b b i t M Q 教 程 】 第 五 章 — — R a b b i t M Q − 死 信 队 列 color{#FF1493}{【RabbitMQ教程】第五章 —— RabbitMQ - 死信队列} 【 R a b b i t M Q 教 程 】 第 五 章 — — R a

    2024年02月09日
    浏览(40)
  • PyTorch高级教程:自定义模型、数据加载及设备间数据移动

    在深入理解了PyTorch的核心组件之后,我们将进一步学习一些高级主题,包括如何自定义模型、加载自定义数据集,以及如何在设备(例如CPU和GPU)之间移动数据。 虽然PyTorch提供了许多预构建的模型层,但在某些情况下,你可能需要自定义模型层。这可以通过继承 torch.nn.Mo

    2024年02月14日
    浏览(32)
  • 第五章:AI大模型的性能评估5.2 评估方法

    随着AI技术的发展,大型AI模型已经成为了研究和实际应用中的重要组成部分。为了确保这些模型的性能和可靠性,性能评估是一个至关重要的环节。在本章中,我们将讨论AI大模型性能评估的核心概念、算法原理、最佳实践以及实际应用场景。 在AI领域,性能评估是指评估模

    2024年02月22日
    浏览(41)
  • 朱长江《偏微分方程简明教程》答案第五章部分(2)

    5.2.1 1、设 L u = u t −

    2024年02月06日
    浏览(38)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包