Pytorch:模块(Module类)

这篇具有很好参考价值的文章主要介绍了Pytorch:模块(Module类)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。


在 PyTorch 中,Module 是一个非常核心的概念,它是所有神经网络层和模型的基础类。torch.nn.Module 是构建所有神经网络的基类,在 PyTorch 中非常重要,因为它提供了网络的组织架构,并封装了权重、梯度的管理、模型参数的更新等功能。

PyTorch 中的 Linear 层ReLU 激活函数以及大多数其他神经网络层和函数都返回 torch.Tensor 类型的对象。这些返回的张量包含了经过相应层或函数处理后的数据。在神经网络中,数据通常以张量的形式在各个层之间流动。

一、Module类介绍

所有神经网络层和模型的基础类,自定义神经网络时对其继承。

1、主要功能

  1. 封装参数

    • Module 类在内部自动管理 层的参数。每当你在 Module 中定义一个层对象,如 self.conv1 = nn.Conv2d(...), PyTorch 自动将这些层的参数加入到模型的参数列表中。这些参数通过 module.parameters() 方法访问。模型参数(定义在模型内部的层的权重和偏置)默认 requires_grad=True
  2. 自动梯度计算

    • 每个 Module 可以使用 PyTorch 的自动微分(autograd)系统来自动计算和存储梯度。在 forward 方法执行运算时,PyTorch 会跟踪这些运算产生的所有张量,对应的梯度在调用张量的 backward()方法后自动计算。由于模型参数默认requires_grad=True,因此对这些参数的所有操作都将被进行自动梯度计算。
  3. 前向传播定义

    • 在定义自己的网络时,需要覆盖 Moduleforward() 方法。这是模型接收输入数据并返回输出的地方forward() 方法定义了模型的前向传播路径
  4. 模型保存和加载
    模型的保存和加载是在 PyTorch 中进行模型持久化和迁移学习的常用操作。模型可以保存为 .pt.pth 文件,包括其参数、优化器状态和其他任何相关的信息。

  • 保存模型:

    • 最简单的保存方法是使用 torch.save 来保存模型的 state_dict,这是一个包含模型参数的字典。
    torch.save(model.state_dict(), 'model_path.pth')
    
  • 加载模型:

    • 加载模型时,首先需要实例化模型对象,然后使用 load_state_dict() 方法加载参数。
    model = MyModel()
    model.load_state_dict(torch.load('model_path.pth'))
    
  1. 将模型移动到指定的设备:
    在 PyTorch 中,可以将模型和数据移动到不同的设备上(如 CPU 或 GPU),以支持不同的计算需求。
  • 使用 .to() 方法可以将模型移动到指定的设备:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
  1. 切换模型的训练和评估方式
    torch.nn.Module 提供了 .train().eval() 方法,用于切换模型的训练和评估模式。
  • 训练模式 (train):

    • 在训练模式下,所有的层都被通知模型正在训练,这对于某些特定层(如 DropoutBatchNorm)非常重要,因为它们在训练和评估时的行为不同。
    model.train()
    
  • 评估模式 (eval):

    • 评估模式用于模型测试或验证阶段,确保所有层都处于评估状态。
    model.eval()
    
  1. .parameters()方法
  • parameters() 方法返回一个迭代器,包含模型中所有的参数(通常用于传递给优化器)。
    for param in model.parameters():
        print(param.size())
    
  1. .modules()方法
  • modules() 方法返回一个迭代器,遍历模型中的所有模块(层)。这在分析模型结构或应用特定操作到每一层时非常有用。
    for module in model.modules():
        print(module)
    

2、神经网络模型使用理解

白话:
  损失函数和优化器都不是module类中的方法,而是外部的方法,但是他们都能够作用于模型的权重:由于自动微分,损失函数接收的是结果张量,因此损失函数带来的梯度会被更新给权重的梯度。而优化器接受的是module对象参数的迭代器,它能根据参数的梯度对参数进行更新。

  自定义神经网络,实际上就是定义一个,该类继承自torch.nn.Module。在对这个类进行实例化时,是使用__init__默认构造函数实例化的。实例化后得到一个神经网络对象,对该对象输入数据会被重载为输入forward函数,而forward函数就是对输入数据进行一层一层的网络层结构处理。forward函数的输出一般是对输入进行了前向传播后的结果,为了对模型参数进行训练更新,我们一般还需要定义一个损失函数;这个损失函数是torch.tensor类型的,可以调用其backward函数,进行反向传播梯度;最后定义一个优化器,进行参数权重更新(实际上这里反向传播梯度 就 相当于损失函数对权重进行求导了,改变权重的方向就是让损失更小的方向。)

反向传播并不更新参数,优化器才是用来更新参数的,反向传播只是更新梯度。 这也是为什么优化器有一个学习率。教程:张量的梯度计算


非白话 自定义神经网络的流程:

  1. 定义一个类,继承自 torch.nn.Module

    • 这个类是您自定义神经网络的基础。通过继承 torch.nn.Module,您的网络能够利用 PyTorch 提供的模块化、参数管理、梯度计算等强大功能。
  2. __init__ 方法中初始化网络层

    • 这是定义神经网络结构的地方。您可以添加诸如全连接层 (nn.Linear), 卷积层 (nn.Conv2d), 激活函数 (nn.ReLU) 等。这些层将被自动注册为模块的子项,使其参数也自动成为模型的一部分。
  3. 定义 forward 方法

    • forward 方法描述了输入数据如何通过定义的层传播。这个方法是在模型训练和评估时自动被调用的,用于前向传播计算输出。
  4. 损失函数和反向传播

    • 在训练阶段,网络输出通过一个损失函数 (loss function) 评估其与真实标签的差异。常用的损失函数有 nn.CrossEntropyLoss(用于分类任务)和 nn.MSELoss(用于回归任务)。
    • 调用损失张量的 .backward() 方法启动自动梯度计算,即反向传播。在这一过程中,PyTorch 根据损失函数自动计算每个参数的梯度,并存储在参数的 .grad 属性中。
    • 在每次迭代后,需要手动清空梯度,以便下一次迭代。如果不清空梯度,梯度会累积,导致不正确的参数更新。清空梯度:optimizer.zero_grad()
  5. 参数更新

    • 使用一个优化器(如 torch.optim.SGDtorch.optim.Adam)来调整网络参数,基于计算的梯度进行更新,以减少损失函数的值。这通常在调用 .backward() 后进行。

a.前向传播示例代码

损失函数和优化器的例子请看:神经网络训练过程代码详解

下面是一个简单的自定义 Module 的例子,定义了一个包含两个全连接层的简单神经网络。文章来源地址https://www.toymoban.com/news/detail-858656.html

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # 定义第一个全连接层
        self.fc1 = nn.Linear(16, 12)
        # 定义第二个全连接层
        self.fc2 = nn.Linear(12, 10)
	
    def forward(self, x):
        # 第一个全连接层的激活函数使用ReLU
        x = F.relu(self.fc1(x))
        # 第二个全连接层的输出
        x = self.fc2(x)
        return x

# 实例化网络
net = SimpleNet()#__init__()里并不需要参数。默认构造函数不需要参数,net就是一个实例化对象。
# 创建一些随机输入数据
input = torch.randn(1, 16)

# 通过网络进行前向传播
output = net(input)#实际上直接使用对象名(),重载为:调用forward函数。
#input先经过一个nn.Linear(16,12),然后进行一次relu(),然后经过一个nn.Linear(12,10)

b.关键点

  • 继承:自定义的模型需要继承自 nn.Module
  • 超类初始化:使用 super() 初始化基类,这是在 Python 类中常见的做法,确保正确初始化父类部分。
  • 定义层:在构造函数中定义网络所需的各种层。
  • 前向传播:在 forward 方法中定义数据如何通过网络。

到了这里,关于Pytorch:模块(Module类)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 人工智能(pytorch)搭建模型10-pytorch搭建脉冲神经网络(SNN)实现及应用

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型10-pytorch搭建脉冲神经网络(SNN)实现及应用,脉冲神经网络(SNN)是一种基于生物神经系统的神经网络模型,它通过模拟神经元之间的电信号传递来实现信息处理。与传统的人工神经网络(ANN)不同,SNN 中的

    2024年02月08日
    浏览(45)
  • 人工智能(pytorch)搭建模型17-pytorch搭建ReitnNet模型,加载数据进行模型训练与预测

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型17-pytorch搭建ReitnNet模型,加载数据进行模型训练与预测,RetinaNet 是一种用于目标检测任务的深度学习模型,旨在解决目标检测中存在的困难样本和不平衡类别问题。它是基于单阶段检测器的一种改进方法,通

    2024年02月15日
    浏览(87)
  • 人工智能:Pytorch,TensorFlow,MXNET,PaddlePaddle 啥区别?

    学习人工智能的时候碰到各种深度神经网络框架:pytorch,TensorFlow,MXNET,PaddlePaddle,他们有什么区别? PyTorch、TensorFlow、MXNet和PaddlePaddle都是深度学习领域的开源框架,它们各自具有不同的特点和优势。以下是它们之间的主要区别: PyTorch是一个开源的Python机器学习库,它基

    2024年04月16日
    浏览(66)
  • 人工智能(Pytorch)搭建模型2-LSTM网络实现简单案例

     本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052  大家好,我是微学AI,今天给大家介绍一下人工智能(Pytorch)搭建模型2-LSTM网络实现简单案例。主要分类三个方面进行描述:Pytorch搭建神经网络的简单步骤、LSTM网络介绍、Pytorch搭建LSTM网络的代码实战 目录

    2024年02月03日
    浏览(59)
  • AI写作革命:PyTorch如何助力人工智能走向深度创新

    身为专注于人工智能研究的学者,我十分热衷于分析\\\"AI写稿\\\"与\\\"PyTorch\\\"这两项领先技术。面对日益精进的人工智能科技,\\\"AI写作\\\"已不再是天方夜谭;而\\\"PyTorch\\\"如璀璨明珠般耀眼,作为深度学习领域的尖端工具,正有力地推进着人工智能化进程。于此篇文章中,我将详细解析\\\"

    2024年04月13日
    浏览(51)
  • 人工智能(pytorch)搭建模型12-pytorch搭建BiGRU模型,利用正态分布数据训练该模型

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型12-pytorch搭建BiGRU模型,利用正态分布数据训练该模型。本文将介绍一种基于PyTorch的BiGRU模型应用项目。我们将首先解释BiGRU模型的原理,然后使用PyTorch搭建模型,并提供模型代码和数据样例。接下来,我们将

    2024年02月09日
    浏览(60)
  • 人工智能(pytorch)搭建模型8-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型8-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别,BiLSTM+CRF 模型是一种常用的序列标注算法,可用于词性标注、分词、命名实体识别等任务。本文利用pytorch搭建一个BiLSTM+CRF模型,并给出数据样例,

    2024年02月09日
    浏览(57)
  • 人工智能(pytorch)搭建模型14-pytorch搭建Siamese Network模型(孪生网络),实现模型的训练与预测

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型14-pytorch搭建Siamese Network模型(孪生网络),实现模型的训练与预测。孪生网络是一种用于度量学习(Metric Learning)和比较学习(Comparison Learning)的深度神经网络模型。它主要用于学习将两个输入样本映射到一个

    2024年02月11日
    浏览(117)
  • 人工智能TensorFlow PyTorch物体分类和目标检测合集【持续更新】

    1. 基于TensorFlow2.3.0的花卉识别 基于TensorFlow2.3.0的花卉识别Android APP设计_基于安卓的花卉识别_lilihewo的博客-CSDN博客 2. 基于TensorFlow2.3.0的垃圾分类 基于TensorFlow2.3.0的垃圾分类Android APP设计_def model_load(img_shape=(224, 224, 3)_lilihewo的博客-CSDN博客   3. 基于TensorFlow2.3.0的果蔬识别系统的

    2024年02月09日
    浏览(58)
  • 人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用,本文将具体介绍DCGAN模型的原理,并使用PyTorch搭建一个简单的DCGAN模型。我们将提供模型代码,并使用一些数据样例进行训练和测试。最后,我们将

    2024年02月08日
    浏览(67)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包