PyTorch模型的保存与加载

这篇具有很好参考价值的文章主要介绍了PyTorch模型的保存与加载。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

载入muti-GPU模型:

pretrain_model = torch.load('muti_gpu_model.pth') # 网络+权重
# 载入为single-GPU模型
gpu_model = pretrained_model.module
# 载入为CPU模型
model = ModelArch()
pretained_dict = pretained_model.module.state_dict()
model.load_satte_dict(pretained_dict)

载入muti-GPU权重:

model = ModelArch().cuda() 
model = torch.nn.DataParallel(model, device_ids=[0]) # 将model转为muti-GPU模式
checkpoint = torch.load(model_path, map_location=lambda storage, loc:storage) 
model.load_state_dict(checkpoint)
# 载入为single-GPU模型
gpu_model = model.module
# 载入为CPU模型
model = ModelArch()
model.load_state_dict(gpu_model.state_dict())
torch.save(cpu_model.state_dict(), 'cpu_model.pth')

载入CPU权重:

# 载入为CPU模型
model = ModelArch()
checkpoint = torch.load(model_path, map_location=lambda storage, loc:storage) 

# 载入为single-GPU模型
model = ModelArch().cuda() 
checkpoint = torch.load(model_path, map_location=lambda storage, loc:storage.cuda(0)) 
model.load_state_dict(checkpoint)

# 载入为muti-GPU模型
model = ModelArch().cuda() 
model = torch.nn.DataParallel(model, device_ids=[0, 1]) 
checkpoint = torch.load(model_path, map_location=lambda storage, loc:storage.cuda(0)) 
model.module.load_state_dict(checkpoint)

1. PyTorch中保存的模型文件.pth

模型保存的格式:pytorch中最常见的模型保存使用 .pt 或者是 .pth 作为模型文件扩展名,其他方式还有.t7/.pkl格式,t7文件是沿用torch7中读取模型权重的方式,而在keras中则是使用.h5文件

.pth 文件基本信息

四个键值:model(OrderedDict),optimizer(Dict),scheduler(Dict),iteration(int)

1)net["model"]   相当于net.state_dict() 返回的字典

键model所对应的值是一个OrderedDict,OrderedDict字典存储着所有的每一层的参数名称以及对应的参数值

Eg. module.backbone.body.stem.conv1.weight,参数名称很长,是因为搭建网络结构的时候采用了组件式的设计,即整个模型里面构造了一个backbone的容器组件,backbone里面又构造了一个body容器组件,body里面又构造了一个stem容器,stem里面的第一个卷积层的权重

2)net["optimizer"]    相当于optimizer.state_dict() 返回的字典

返回的是一个一般的字典 Dict 对象,这个字典只有两个key:state和param_groups

  • param_groups对应的值是一个列表;
  • state对应的值是一个字典类型,和param_groups有着对应关系,每一个元素的键值就是param_groups中每一个元素的params;

3)net["scheduler"] 返回一个字典

4)net["iteration"]  返回一个具体的数字

2. torch.save()函数:保存模型文件

注意:.pt, .pth, .pkl并不是在格式上有区别,只是后缀不同而已(仅此而已)

pytorch模型保存的两种方式:一种是保存整个模型,另一种是只保存模型的参数

torch.save(model.state_dict(), "my_model.pth") # 只保存模型的参数
torch.save(model, "my_model.pth") # 保存整个模型

保存的模型参数:一个字典类型,通过key-value的形式来存储模型的所有参数

3. torch.load()函数:用来加载torch.save()保存的模型文件

torch.load()先在CPU上加载,不会依赖于保存模型的设备。如果加载失败,可能是因为没有包含某些设备,比如在gpu上训练保存的模型,而在cpu上加载,可能会报错,此时,需要使用map_location来将存储动态重新映射到可选设备上

4. torch.nn.Module类model.state_dict()方法

state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系

注意:

1)只有参数可以训练的layer才会被保存到模型的state_dict中,如卷积层、线性层等,池化层这些本身没有参数的层没有在这个字典中;

2)作用:方便查看某一个层的权值和偏置数据;在模型保存的时候使用。

优化器对象Optimizer也有state_dict,包含了Optimizer状态以及超参数(如lr, momentum,weight_decay等)

5. torch.nn.Module类model.parameters()方法:获得模型的参数信息

  • model.parameters()方法返回的是一个生成器generator,每一个元素是从开头到结尾的参数,parameters没有对应的key,是一个由纯参数组成的generator,而state_dict是一个字典,包含了key;
  • parameters是通过named_parameters来实现的,也是Module一个与parameters类似的函数。
# 查看model的参数量:先load model的weight,然后再使用parameters()
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

总结:model.state_dict()、model.parameters()、model.named_parameters()这三个方法都可以查看Module的参数信息,用于更新参数,或者用于模型的保存

6. checkpoint:保存模型的参数,优化器参数,loss,epoch等

(相当于一个保存模型的文件夹)

checkpoint的机制:在模型训练的过程中,不断地保存训练结果(包括但不限于EPOCH、模型权重、优化器状态、调度器状态),即便模型训练中断,也可以基于checkpoint接续训练

在反向传播时重新计算深度神经网络的中间值(而通常情况是在前向传播时存储的),这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)

7. 内存开销

神经网络使用的总内存:

  • 静态内存,尽管 PyTorch 模型中内置了一些固定开销,但总的来说几乎完全由模型权重决定
  • 模型的计算图所占用的动态内存,在训练模式下,每次通过神经网络的前向传播都为网络中的每个神经元计算一个激活值,这个值随后被存储在所谓的计算图中。必须为批次中的每个单个训练样本存储一个值,因此数量会迅速的累积起来。总成本取决于模型大小和批处理大小,并设置适用于GPU内存的最大批处理大小的限制

PyTorch 通过torch.utils.checkpoint.checkpoint和torch.utils.checkpoint.checkpoint_sequential提供梯度检查点,在前向传播时,PyTorch 将保存模型中的每个函数的输入元组。在反向传播过程中,对于每个函数,输入元组和函数的组合以实时的方式重新计算插入到每个需要它的函数的梯度公式中然后丢弃(显存中只保存输入数据和函数)。网络计算开销大致相当于每个样本通过模型前向传播开销的两倍。文章来源地址https://www.toymoban.com/news/detail-420183.html

到了这里,关于PyTorch模型的保存与加载的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 利用pytorch自定义CNN网络(五):保存、加载自定义模型【转载】

    本文转载自: PyTorch | 保存和加载模型 本文主要介绍如何加载和保存 PyTorch 的模型。这里主要有三个核心函数: torch.save :把序列化的对象保存到硬盘。它利用了 Python 的 pickle 来实现序列化。模型、张量以及字典都可以用该函数进行保存; torch.load:采用 pickle 将反序列化的

    2024年02月13日
    浏览(41)
  • 如何用pytorch做文本摘要生成任务(加载数据集、T5 模型参数、微调、保存和测试模型,以及ROUGE分数计算)

    摘要 :如何使用 Pytorch(或Pytorchlightning) 和 huggingface Transformers 做文本摘要生成任务,包括数据集的加载、模型的加载、模型的微调、模型的验证、模型的保存、ROUGE指标分数的计算、loss的可视化。 ✅ NLP 研 0 选手的学习笔记 ● python 需要 3.8+ ● 文件相对地址 : mian.py 和 tra

    2024年02月05日
    浏览(73)
  • 深度学习技术栈 —— Pytorch中保存与加载权重文件

    权重文件是指训练好的模型参数文件,不同的深度学习框架和模型可能使用不同的权重文件格式。以下是一些常见的权重文件格式: PyTorch 的模型格式: .pt 文件。 Darknet 的模型格式: .weight 文件。 TensorFlow 的模型格式: .ckpt 文件。 一、参考文章或视频链接 [1] Navigating Mode

    2024年01月19日
    浏览(57)
  • 第五章 模型篇: 模型保存与加载

    参考教程 : https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html 训练好的模型,可以保存下来,用于后续的预测或者训练过程的重启。 为了便于理解模型保存和加载的过程,我们定义一个简单的小模型作为例子,进行后续的讲解。 这个模型里面包含一个名为self.p1的Para

    2024年02月10日
    浏览(36)
  • pytorch从python转 c++涉及到的数据保存加载问题;libtorch

    python代码 c++代码 python代码 c++代码

    2024年02月13日
    浏览(46)
  • 模型加载至 cpu 和 gpu 的方式

    采用 from_pretrained 的方式,模型正常情况下,BertMoldel.from_pretrained() 是会 load 在 cpu 上的,内部 map_location 默认设置成 cpu,如果想要部署在gpu,执行下面三句话。 采用 load_state_dict 的方式加载模型,模型是部署在 哪里可以指定,如果想部署到 gpu,无需修改第一行,直接再加入

    2024年02月15日
    浏览(34)
  • Tensorflow实现训练数据的加载—模型搭建训练保存—模型调用和加载全流程

     将tensorflow的训练数据数组(矩阵)保存为.npy的数据格式。为后续的模型训练提供便捷的方法。例如如下:   加载.npy训练数据和测试数组(矩阵),加载后需要调整数据的形状以满足设计模型的输入输出需求,不然无法训练模型。 这里可以采用自定义层和tensorflow的API搭建

    2024年02月05日
    浏览(36)
  • 在pytorch中保存模型或模型参数

    在 PyTorch 中,我们可以使用 torch.save 函数将 PyTorch 模型保存到文件。这个函数接受两个参数:要保存的对象(通常是模型),以及文件路径。 在上面的例子中, model.state_dict() 用于获取模型的状态字典(包含模型的所有参数)。然后, torch.save 函数将这个状态字典保存到指定

    2024年02月05日
    浏览(36)
  • Docker保存镜像到本地并载入本地镜像文件

    目录 一、适用情况 二、镜像保存到本机  1、查看已有的镜像文件 2、将镜像保存为本地文件 保存指令一 保存指令二 测试根据镜像ID保存镜像 三、载入本地镜像 载入指令一 载入指令二 载入通过镜像ID保存的本地镜像 四、批量保存和载入镜像脚本 批量保存镜像到本地脚本

    2024年02月13日
    浏览(36)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包