深度学习——神经网络参数的保存和提取

这篇具有很好参考价值的文章主要介绍了深度学习——神经网络参数的保存和提取。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

代码与详细注释:

Talk is cheap. Show you the code!

import torch
import matplotlib.pyplot as plt


# 造数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)


# 保存net1
def save():
    # 使用Sequential快速搭建神经网络
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    # 优化器设置为SGD
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    # 损失函数设置为MSELoss
    loss_func = torch.nn.MSELoss()

    # 训练100步
    for t in range(100):
        # 计算预测值
        prediction = net1(x)
        # 计算预测值和真实值之间的误差
        loss = loss_func(prediction, y)
        # 将梯度设置为0
        optimizer.zero_grad()
        # 误差反向传播
        loss.backward()
        # 优化器逐步优化
        optimizer.step()

    # 绘制第一张子图
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    # 绘制原数据的散点图
    plt.scatter(x.data.numpy(), y.data.numpy())
    # 绘制回归曲线
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)

    # 方法一:保存整个网络
    torch.save(net1, 'net.pkl')
    # 方法二:只保存网络参数
    torch.save(net1.state_dict(), 'net_params.pkl')


# 提取
def restore_net():
    # 方法一:提取整个网络
    net2 = torch.load('net.pkl')
    # 使用net2进行预测
    prediction = net2(x)

    # 绘制第二张子图
    plt.subplot(132)
    # 设置子图标题
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)


def restore_params():
    # 使用Sequential快速搭建一个和net1结构一致的网络net3
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    # 方法二:加载net1的参数,提取,然后赋给net3的参数
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)

    # 绘制子图3
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.show()


# net1进行预测,并且用两种方法进行保存模型
save()

# 提取整个网络到net2进行预测并绘图
restore_net()

# 提取net1的网络参数,然后赋给net3预测并绘图
restore_params()

运行结果:

深度学习——神经网络参数的保存和提取,深度学习,PyTorch,深度学习,神经网络,人工智能文章来源地址https://www.toymoban.com/news/detail-556862.html

因为网络结构和网络参数是一样的,所以训练出来的效果也是一致的!

到了这里,关于深度学习——神经网络参数的保存和提取的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 论文笔记(五)FWENet:基于SAR图像的洪水水体提取深度卷积神经网络(CVPR)

    FWENet: a deep convolutional neural network for flood water body extraction based on SAR images 作者:Jingming Wang, Shixin Wang, Futao Wang, Yi Zhou, Zhenqing Wang, Jianwan Ji, Yibing Xiong Qing Zhao 期刊:Internation Journal of Digital Earth 日期:2022 :深度学习;洪水水体提取;SAR;鄱阳湖 原文:https://doi.org/10.1080

    2024年02月10日
    浏览(43)
  • 深度神经网络基础——深度学习神经网络基础 & Tensorflow在深度学习的应用

    Tensorflow入门(1)——深度学习框架Tesnsflow入门 环境配置 认识Tensorflow 深度学习框架Tesnsflow 线程+队列+IO操作 文件读取案例 神经网络的种类: 基础神经网络:单层感知器,线性神经网络,BP神经网络,Hopfield神经网络等 进阶神经网络:玻尔兹曼机,受限玻尔兹曼机,递归神经

    2024年02月16日
    浏览(47)
  • 竞赛 深度学习卷积神经网络垃圾分类系统 - 深度学习 神经网络 图像识别 垃圾分类 算法 小程序

    🔥 优质竞赛项目系列,今天要分享的是 深度学习卷积神经网络垃圾分类系统 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng-senior/postgraduate 近年来,随着我国经济的快速发展,国家各项建设都蒸蒸日上,成绩显著。

    2024年02月08日
    浏览(47)
  • 竞赛选题 深度学习卷积神经网络垃圾分类系统 - 深度学习 神经网络 图像识别 垃圾分类 算法 小程序

    🔥 优质竞赛项目系列,今天要分享的是 深度学习卷积神经网络垃圾分类系统 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng-senior/postgraduate 近年来,随着我国经济的快速发展,国家各项建设都蒸蒸日上,成绩显著。

    2024年02月07日
    浏览(52)
  • Python中的深度学习:神经网络与卷积神经网络

    当下,深度学习已经成为人工智能研究和应用领域的关键技术之一。作为一个开源的高级编程语言,Python提供了丰富的工具和库,为深度学习的研究和开发提供了便利。本文将深入探究Python中的深度学习,重点聚焦于神经网络与卷积神经网络的原理和应用。 深度学习是机器学

    2024年02月08日
    浏览(45)
  • 【AI】深度学习——前馈神经网络——全连接前馈神经网络

    前馈神经网络(Feedforward Neural Network,FNN)也称为多层感知器(实际上前馈神经网络由多层Logistic回归模型组成) 前馈神经网络中,各个神经元属于不同的层 每层神经元接收前一层神经元的信号,并输出到下一层 输入层:第0层 输出层:最后一层 隐藏层:其他中间层 整个网络

    2024年04月12日
    浏览(109)
  • 深度学习卷积神经网络垃圾分类系统 - 深度学习 神经网络 图像识别 垃圾分类 算法 小程序 计算机竞赛

    🔥 优质竞赛项目系列,今天要分享的是 深度学习卷积神经网络垃圾分类系统 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng-senior/postgraduate 近年来,随着我国经济的快速发展,国家各项建设都蒸蒸日上,成绩显著。

    2024年02月04日
    浏览(57)
  • 计算机竞赛 深度学习卷积神经网络垃圾分类系统 - 深度学习 神经网络 图像识别 垃圾分类 算法 小程序

    🔥 优质竞赛项目系列,今天要分享的是 深度学习卷积神经网络垃圾分类系统 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng-senior/postgraduate 近年来,随着我国经济的快速发展,国家各项建设都蒸蒸日上,成绩显著。

    2024年02月07日
    浏览(59)
  • 【机器学习】——深度学习与神经网络

    目录 引入 一、神经网络及其主要算法 1、前馈神经网络 2、感知器 3、三层前馈网络(多层感知器MLP) 4、反向传播算法 二、深度学习 1、自编码算法AutorEncoder 2、自组织编码深度网络 ①栈式AutorEncoder自动编码器 ②Sparse Coding稀疏编码 3、卷积神经网络模型(续下次) 拓展:

    2024年02月09日
    浏览(46)
  • 【机器学习】——神经网络与深度学习

    目录 引入 一、神经网络及其主要算法 1、前馈神经网络 2、感知器 3、三层前馈网络(多层感知器MLP) 4、反向传播算法 二、深度学习 1、自编码算法AutorEncoder 2、自组织编码深度网络 ①栈式AutorEncoder自动编码器 ②Sparse Coding稀疏编码 3、卷积神经网络模型(续下次) 拓展:

    2024年02月10日
    浏览(52)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包