PyTorch翻译官网教程8-SAVE AND LOAD THE MODEL

这篇具有很好参考价值的文章主要介绍了PyTorch翻译官网教程8-SAVE AND LOAD THE MODEL。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

官网链接

Save and Load the Model — PyTorch Tutorials 2.0.1+cu117 documentation

保存和加载模型

在本节中,我们将了解如何通过保存、加载和运行模型预测来持久化模型状态。

import torch
import torchvision.models as models

保存和加载模型权重

PyTorch模型将学习到的参数存储在一个名为state_dict的内部状态字典中。这些可以通过torch.save 方法持久化

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

输出

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/vgg16-397923af.pth

  0%|          | 0.00/528M [00:00<?, ?B/s]
  5%|4         | 23.9M/528M [00:00<00:02, 250MB/s]
 10%|9         | 50.7M/528M [00:00<00:01, 268MB/s]
 16%|#6        | 85.5M/528M [00:00<00:01, 313MB/s]
 22%|##2       | 118M/528M [00:00<00:01, 322MB/s]
 28%|##8       | 148M/528M [00:00<00:01, 304MB/s]
 34%|###3      | 178M/528M [00:00<00:01, 292MB/s]
 39%|###8      | 206M/528M [00:00<00:01, 285MB/s]
 44%|####4     | 233M/528M [00:00<00:01, 267MB/s]
 49%|####8     | 258M/528M [00:00<00:01, 267MB/s]
 54%|#####3    | 284M/528M [00:01<00:00, 269MB/s]
 59%|#####8    | 310M/528M [00:01<00:00, 269MB/s]
 64%|######3   | 336M/528M [00:01<00:00, 270MB/s]
 69%|######8   | 362M/528M [00:01<00:00, 269MB/s]
 74%|#######3  | 388M/528M [00:01<00:00, 270MB/s]
 78%|#######8  | 414M/528M [00:01<00:00, 270MB/s]
 83%|########3 | 440M/528M [00:01<00:00, 270MB/s]
 88%|########8 | 465M/528M [00:01<00:00, 270MB/s]
 93%|#########3| 491M/528M [00:01<00:00, 270MB/s]
 98%|#########8| 517M/528M [00:01<00:00, 271MB/s]
100%|##########| 528M/528M [00:02<00:00, 276MB/s]

要加载模型权重,需要首先创建同一模型的实例,然后使用load_state_dict() 方法加载参数。

model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

输出

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意

一定要在推理之前调用model.eval() 方法,将dropout层 和normalization 层设置成evaluation模式。如果不这样做,将产生不一致的推理结果。

保存和加载模型与形状

当加载模型权重时,我们需要首先实例化模型类,因为类定义了网络的结构。我们可能希望将该类的结构与模型一起保存,在这种情况下,我们可以使用model(而不是model.state_dict())方法保存:

torch.save(model, 'model.pth')

然后我们可以像这样加载模型:

model = torch.load('model.pth')

注意

这种方法在序列化模型时使用Python pickle模块,因此它依赖于 加载模型时可用的实际类定义。

相关教程

Saving and Loading a General Checkpoint in PyTorch


 文章来源地址https://www.toymoban.com/news/detail-571545.html


 

到了这里,关于PyTorch翻译官网教程8-SAVE AND LOAD THE MODEL的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • QGraphicsView实现简易地图4『局部加载-地图漫游』

    前文链接:QGraphicsView实现简易地图3『局部加载-地图缩放』 当鼠标拖动地图移动时,需要实时增补和删减瓦片地图,大致思路是计算地图从各方向移动时进出视口的瓦片坐标值,根据变化后的瓦片坐标值来增减地图瓦片,以下将提供实现此需求的核心代码。 1、动态演示效果

    2024年02月13日
    浏览(22)
  • elementplus实现左侧菜单栏收缩与展开

    Home.vue下包含aside.vue和menu.vue 注意: 要使用收缩与展开,el-aside必须设置 width=\\\"collapse\\\" ,否则收缩展开会出现收缩后,el-aside宽度不变窄 需要使用动态改变展开收缩值 :collapse=\\\"isCollapse\\\" @open=\\\"handleOpen\\\"展开后改变isCollapse的值(@close=\\\"handleClose\\\"不生效也不影响效果) :collapse-trans

    2024年02月10日
    浏览(24)
  • 5款软件压力测试工具分享

    一、什么是软件压力测试? 软件压力测试是一种基本的质量保证行为,它是每个重要软件测试工作的一部分。软件压力测试的基本思路很简单:不是在常规条件下运行手动或自动测试,而是在计算机数量较少或系统资源匮乏的条件下运行测试。通常要进行软件压力测试的资源

    2024年02月02日
    浏览(31)
  • Ubuntu 18.04开发环境搭建

            工作不易,为了避免未来需要重装系统的进行折腾,个人进行了Ubuntu环境配置的整合,方便自己未来能顺畅的配置好开发环境,同时分享给大家。本文多出有转载其他文,并相应的标注了转载内容,如有侵权请联系博主删除。 vmware下载: 链接:https://pan.baidu.com

    2024年02月02日
    浏览(41)
  • chatgpt赋能python:如何使用Python得到8/3的小数部分

    在数学中,8/3是一个分数,可以被表示为2.6666666666666665。然而,在Python中,我们可以使用一些技巧来得到它的小数部分。 小数部分是一个数的小数点后的部分,与整数部分相对。在数学中,我们可以使用floor和mod操作来获得一个数的整数和小数部分。 floor 操作可以将一个数向

    2024年02月08日
    浏览(25)
  • 第二章 图像基本运算及变换

    本章主要讲解图像的一些基本运算及仿射变换以及透视变换。 图像相加 imgA + imgB :当其和大于一个字节时, 大于一个字节的位数将被丢失,类似于取模。 ( A + B ) % 256 (A + B) % 256 ( A + B ) %256 cv2.add(imgA, imgB) :当数值超过 255 时,取值为 255 m i n ( A + B , 255 ) min(A + B, 255) min ( A

    2024年02月03日
    浏览(40)
  • 【图论C++】树的直径(DFS 与 DP动态规划)

    UpData Log👆 2023.9.27 更新进行中 Statement0🥇 一起进步 Statement1💯 有些描述是个人理解,可能不够标准,但能达其意 21-1-1 定义 树上 最远的两个节点之间 的距离被称为 树的直径 ,连接这两个点的路径 被称为 树的最长链 。 21-1-2 性质 1 、这两个最远点一定是叶子节点 1、这 两

    2024年02月07日
    浏览(32)
  • Axie Infinity 超级任务远超预期,和 YGG 一起探索 Web3 增长新方式!

    参与超级任务的实际人数是预期人数的两倍。 超级任务将新玩家引入 Web3 游戏领域,并向他们介绍可以为其玩家旅程提供支持的社区。 Axie Infinity 超级任务旨在向新手和 Axie Classic 老 玩家介绍「Axie Infinity|起源」这款游戏。 整个活动共吸引了 4,322 名玩家参与任务,是预期注

    2024年02月06日
    浏览(26)
  • Spring Security 6.x 系列【72】授权篇之角色分层

    有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot 版本 3.1.0 本系列Spring Security 版本 6.1.0 源码地址:https://gitee.com/pearl-organization/study-spring-security-demo

    2024年01月23日
    浏览(34)
  • FPGA的主流技术与市场表现方面的调研报告

    撰写简单的FPGA的主流技术与市场表现方面的调研报告,表达自己的认知和发展展望,500字,图片,表格除外 FPGA(Field-Programmable Gate Array)是一种可编程逻辑器件,是在PAL (可编程阵列逻辑)、GAL(通用阵列逻辑)等可编程器件的基础上进一步发展的产物,广泛应用于通信、

    2024年02月07日
    浏览(32)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包