pytorch 固定部分网络参数需要使用 with torch.no_grad()吗

这篇具有很好参考价值的文章主要介绍了pytorch 固定部分网络参数需要使用 with torch.no_grad()吗。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

在 PyTorch 中,torch.no_grad() 是一个上下文管理器,用于设置一段代码的计算图不需要梯度。具体来说,当我们在 torch.no_grad() 的上下文中执行某些操作时,PyTorch 不会为这些操作自动计算梯度,以节省计算资源。

使用 torch.no_grad() 可以有如下几种情况:

  1. 测试模型:在测试模型或部分模型时,我们不需要计算梯度,因为这些操作不会影响我们的模型的训练。此时,可以使用 torch.no_grad()。

  2. 固定模型参数:有时我们可能需要固定模型的某些参数,例如在微调(fine-tuning)预训练模型时,我们可能只需要更新一部分参数,而其他参数应该被固定下来。此时,可以使用 torch.no_grad() 来固定特定的参数。

在 PyTorch 中,固定部分网络参数不一定需要使用 torch.no_grad()。当我们将需要固定的参数的 requires_grad 属性设置为 False 时,这些参数在计算梯度时就不会被更新,因此不需要使用 torch.no_grad()。

然而,当我们在使用不需要更新的参数进行前向传递时,如果不使用 torch.no_grad(),PyTorch 会默认计算梯度,这会浪费计算资源。因此,为了节省计算资源,建议在使用不需要更新的参数进行前向传递时使用 torch.no_grad()。

下面是一个示例代码,演示如何使用 torch.no_grad() 来固定部分网络参数:

import torch
import torch.nn as nn

# 创建一个简单的网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
        
    def forward(self, x):
       
        x = self.fc1(x)
        x = nn.functional.relu(x)
        with torch.no_grad():
            x = self.fc2(x)
        return x

# 创建输入和标签
inputs = torch.randn(3, 10)
labels = torch.tensor([[1.0], [0.0], [1.0]])

# 创建网络和优化器
net = Net()

# 前向传递计算
outputs = net(inputs)

# 在测试模型时,可以使用 torch.no_grad() 来禁用梯度计算

    # 对输出进行操作,但不需要计算梯度
outputs = nn.functional.sigmoid(outputs)

# 计算损失函数
loss = nn.functional.binary_cross_entropy_with_logits(outputs, labels)

# 反向传播计算梯度
loss.backward()


for name, param in net.named_parameters():
    print(name, param.grad)


输出:
报错: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

在上述代码中,我们首先定义了一个简单的 Net 网络,并使用 torch.no_grad() 来禁用了在 fc2 层的梯度计算。然后,我们对输出进行了操作,并计算了损失函数和梯度。最后,我们输出了每个参数的梯度。可以看到,由于我们在 fc2 层使用了 torch.no_grad(),因此 fc2 层的参数的梯度为 None,而 fc1 层的参数的梯度正常计算。所以由于 fc2 层的梯度为None。所以反向传播会直接报错。

我的建议:

with torch.no_grad()仅仅在测试的时候用就行,固定参数 直接requires_grad = False就可以了。

with torch.no_grad() 很容易导致反向传播的时候 某些层 无法计算梯度,导致RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn错误。

所以 with torch.no_grad()尽量放在网络前面层,不用放在最后面的层,比如上面这个例子,固定fc2,梯度会导致无法传播到fc1,导致报错。

如果修改一下 固定fc1,就没有错误

import torch
import torch.nn as nn

# 创建一个简单的网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
        
    def forward(self, x):
        with torch.no_grad():
            x = self.fc1(x)
        x = nn.functional.relu(x)
        
        x = self.fc2(x)
        return x

# 创建输入和标签
inputs = torch.randn(3, 10)
labels = torch.tensor([[1.0], [0.0], [1.0]])

# 创建网络和优化器
net = Net()

# 前向传递计算
outputs = net(inputs)

# 在测试模型时,可以使用 torch.no_grad() 来禁用梯度计算

    # 对输出进行操作,但不需要计算梯度
outputs = nn.functional.sigmoid(outputs)

# 计算损失函数
loss = nn.functional.binary_cross_entropy_with_logits(outputs, labels)

# 反向传播计算梯度
loss.backward()


for name, param in net.named_parameters():
    print(name, param.grad)

输出:
fc1.weight None
fc1.bias None
fc2.weight tensor([[ 0.0246,  0.0000, -0.0222,  0.0000, -0.0331]])
fc2.bias tensor([-0.0125])

原因就是 fc2不计算梯度,但是fc1 需要梯度更新,梯度需要传到 fc1 ,所以报错,
但 fc1 不计算梯度,此时 fc1梯度为None 不会影响 fc2文章来源地址https://www.toymoban.com/news/detail-516530.html

到了这里,关于pytorch 固定部分网络参数需要使用 with torch.no_grad()吗的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 详解Pytorch中的torch.nn.MSELoss函,包括对每个参数的分析!

    一、函数介绍 Pytorch中MSELoss函数的接口声明如下,具体网址可以点这里。 torch.nn.MSELoss(size_average=None, reduce=None, reduction=‘mean’) 该函数 默认用于计算两个输入对应元素差值平方和的均值 。具体地,在深度学习中,可以使用该函数用来计算两个特征图的相似性。 二、使用方式

    2023年04月19日
    浏览(47)
  • 使用PyTorch构建神经网络,并计算参数Params

    在深度学习中,模型的参数数量是一个非常重要的指标,通常会影响模型的大小、训练速度和准确度等多个方面。在本教程中,我们将介绍如何计算深度学习模型的参数数量。 本教程将以PyTorch为例,展示如何计算一个包含卷积、池化、归一化和全连接等多种层的卷积神经网

    2024年02月03日
    浏览(45)
  • 使用PyTorch构建神经网络,并使用thop计算参数和FLOPs

    FLOPs和FLOPS区别 FLOPs(floating point operations)是指浮点运算次数,通常用来评估一个计算机算法或者模型的计算复杂度。在机器学习中,FLOPs通常用来衡量神经网络的计算复杂度,因为神经网络的计算主要由矩阵乘法和卷积操作组成,而这些操作都可以转化为浮点运算次数的形式

    2024年02月03日
    浏览(44)
  • Pytorch学习:神经网络模块torch.nn.Module和torch.nn.Sequential

    官方文档:torch.nn.Module CLASS torch.nn.Module(*args, **kwargs) 所有神经网络模块的基类。 您的模型也应该对此类进行子类化。 模块还可以包含其他模块,允许将它们嵌套在树结构中。您可以将子模块分配为常规属性: training(bool) -布尔值表示此模块是处于训练模式还是评估模式。

    2024年02月10日
    浏览(43)
  • 【AI】《动手学-深度学习-PyTorch版》笔记(十六):自定义网络层、保存/加载参数、使用GPU

    自定义网络层很简单,三步即可完成 继承类:nn.Module 定义初始化函数:__init__中定义需要初始化的代码 定义向前传播函数:forward 1)定义网络层

    2024年02月13日
    浏览(47)
  • 【pytorch】torch.cdist使用说明

    torch.cdist的使用介绍如官网所示, 它是批量计算两个向量集合的距离。 其中, x1和x2是输入的两个向量集合。 p 默认为2,为欧几里德距离。 它的功能上等同于 scipy.spatial.distance.cdist (input,’minkowski’, p=p) 如果x1的shape是 [B,P,M], x2的shape是[B,R,M],则cdist的结果shape是 [B,P,R] x1一般

    2024年01月15日
    浏览(47)
  • Pytorch学习笔记(5):torch.nn---网络层介绍(卷积层、池化层、线性层、激活函数层)

     一、卷积层—Convolution Layers  1.1 1d / 2d / 3d卷积 1.2 卷积—nn.Conv2d() nn.Conv2d 1.3 转置卷积—nn.ConvTranspose nn.ConvTranspose2d  二、池化层—Pooling Layer (1)nn.MaxPool2d (2)nn.AvgPool2d (3)nn.MaxUnpool2d  三、线性层—Linear Layer  nn.Linear  四、激活函数层—Activate Layer (1)nn.Sigmoid  (

    2024年01月20日
    浏览(44)
  • PyTorch之Torch Script的简单使用

    TorchScript 简介 Torch Script Loading a TorchScript Model in C++ TorchScript 解读(一):初识 TorchScript libtorch教程(一)开发环境搭建:VS+libtorch和Qt+libtorch Torch Script 是一种序列化和优化 PyTorch 模型的格式,在优化过程中,一个 torch.nn.Module 模型会被转换成 Torch Script 的 torch.jit.ScriptModule 模

    2024年04月09日
    浏览(33)
  • 基于Pytorch的神经网络部分自定义设计

            本质上,优化和深度学习的目标是根本不同的。前者主要关注的是最小化目标,后者则关注在给定有限数据量的情况下寻找合适的模型。训练误差和泛化误差通常不同:由于优化算法的目标函数通常是基于训练数据集的损失函数,因此优化的目标是减少训练误差。但

    2024年02月10日
    浏览(32)
  • linux ubuntu20.04固定ip设置方法(静态ip)(没有以太网网络设置界面)(虚拟机的话需要设置为桥接模式)(ubuntu虚拟机固定ip地址)(VMware虚拟机)

    新买的浪潮服务器,想设置固定ip,不知咋滴,界面上没有以太网网络设置的地方,试了很多方法都不行 后来发现直接修改 /etc/netplan/ 下的配置文件,能修改成功,现把方法记录下来 首先查看服务器上以太网口,一般插上网线后,总有一个不一样的 我用 ifconfig 查看,这个网

    2024年02月03日
    浏览(77)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包