pytorch中nn.Parameter()使用方法

这篇具有很好参考价值的文章主要介绍了pytorch中nn.Parameter()使用方法。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

对于nn.Parameter()是pytorch中定义可学习参数的一种方法,因为我们在搭建网络时,网络中会存在一些矩阵,这些矩阵内部的参数是可学习的,也就是可梯度求导的。

对于一些常用的网络层,例如nn.Conv2d()卷积层nn.LInear()线性层nn.LSTM()循环网络层等,这些网络层在pytorch中的nn模块中已经定义好,所以我们搭建模型时可以直接使用,但是有些自定义网络在pytorch中是没有实现的,我们就需要自定义可学习参数,那就用到了nn.Parameter()这个函数。

该函数会为我们创建一个矩阵,该矩阵是默认可梯度求导的,之后我们就可以利用这个矩阵进行计算,该函数需要传入的参数是一个tensor,一般我们会传入一个初始化好的tensor。

下面我们将使用一个简单的线性层作为实例,来理解如何使用nn.Parameter()。

一、nn.Linear()定义参数

在类中我们定义了一个线性层,输入维度是10,输出维度是3,对于nn.Linear()层内部已经封装好了nn.Parameter(),所以不需要我们自定义,直接使用即可。

class Net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 3)
    
    def forward(self, x):
        return F.sigmoid(self.linear(x))

二、nn.Parameter()定义参数

对于一个线性层,我们会需要两个矩阵,分别是权重W和偏置b,所以我们要用nn.Parameter()定义两个可学习参数,然后传入对应维度的tensor作为参数,之后就可以在forward中定义计算过程。

class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.randn(10, 3))
        self.b = nn.Parameter(torch.randn(3))
    
    def forward(self, x):
        return F.sigmoid(self.W @ x + self.b)

三、查看可学习参数

利用下面代码就可以看定义好的模型中的参数文章来源地址https://www.toymoban.com/news/detail-537686.html

model1 = Net1()
model2 = Net2()

for name, parameters in model1.named_parameters():
    print(name, ':', parameters.size())
    
for name, parameters in model2.named_parameters():
    print(name, ':', parameters.size())
linear.weight : torch.Size([3, 10])
linear.bias : torch.Size([3])
W : torch.Size([10, 3])
b : torch.Size([3])

到了这里,关于pytorch中nn.Parameter()使用方法的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【Pytorch:nn.Embedding】简介以及使用方法:用于生成固定数量的具有指定维度的嵌入向量embedding vector

    首先我们讲解一下关于嵌入向量embedding vector的概念 1) 在自然语言处理NLP领域,是将单词、短语或其他文本单位映射到一个固定长度的实数向量空间中 。嵌入向量具有较低的维度,通常在几十到几百维之间,且每个维度都包含一定程度上的语义信息。这意味着在嵌入向量空

    2024年02月12日
    浏览(23)
  • nn.Parameter()

    nn.Parameter() 是 PyTorch 中的一个类,用于创建可训练的参数(权重和偏置),这些参数会在模型训练过程中自动更新。 nn.Parameter() 具有以下特点: nn.Parameter() 继承自 torch.Tensor ,因此它本质上也是一个张量(tensor),可以像普通张量一样进行各种张量操作,例如加法、乘法、索

    2024年02月09日
    浏览(19)
  • TypeError: cannot assign ‘torch.cuda.FloatTensor‘ as parameter ‘bias‘ (torch.nn.Parameter or None ex

    报错定位到的位置是在: self.bias = self.bias.cuda() 意为将把bias转到gpu上报错; 网上查询了很多问题都没解决,受到这篇博客的启发;pytorch 手动设置参数变量 并转到cuda上_XiaoPangJix1的博客-CSDN博客 原因可能是:bias是torch.nn.Parameter(),转移到cuda上失败,提示此报错; 其实根本原因

    2024年02月16日
    浏览(44)
  • Pytorch基本概念和使用方法

    目录 1 Adam及优化器optimizer(Adam、SGD等)是如何选用的? 1)Momentum 2)RMSProp 3)Adam 2 Pytorch的使用以及Pytorch在以后学习工作中的应用场景。 1)Pytorch的使用 2)应用场景 3 不同的数据、数据集加载方式以及加载后各部分的调用处理方式。如DataLoder的使用、datasets内置数据集的使

    2024年02月07日
    浏览(33)
  • PyTorch中grid_sample的使用方法

    官方文档 首先Pytorch中grid_sample函数的接口声明如下: input : 输入tensor, shape为 [N, C, H_in, W_in] grid: 一个field flow, shape为[N, H_out, W_out, 2],最后一个维度是每个grid(H_out_i, W_out_i)在input的哪个位置的邻域去采点。数值范围被归一化到[-1,1]。 这里的input和output就是输入的图片,或

    2024年02月08日
    浏览(21)
  • PyTorch多GPU训练模型——使用单GPU或CPU进行推理的方法

    PyTorch提供了非常便捷的多GPU网络训练方法: DataParallel 和 DistributedDataParallel 。在涉及到一些复杂模型时,基本都是采用多个GPU并行训练并保存模型。但在推理阶段往往只采用单个GPU或者CPU运行。这时怎么将多GPU环境下保存的模型权重加载到单GPU/CPU运行环境下的模型上成了一

    2024年02月09日
    浏览(36)
  • Pytorch常用的函数(二)pytorch中nn.Embedding原理及使用

    图像数据表达不需要特殊的编码,并且有天生的顺序性和关联性,近似的数字会被认为是近似的特征。 正如图像是由像素组成,语言是由词或字组成,可以把语言转换为词或字表示的集合。 然而,不同于像素的大小天生具有色彩信息,词的数值大小很难表征词的含义。最初

    2024年02月09日
    浏览(32)
  • PyTorch 之 简介、相关软件框架、基本使用方法、tensor 的几种形状和 autograd 机制

    本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052 PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。它主要由 Facebook 的人工智能小组开发,不仅能够实现强大的 GPU 加速,同时还支持动态神经网络,这一点是现在很多主流框架

    2024年01月18日
    浏览(42)
  • pytorch(6)——神经网络基本骨架nn.module的使用

    torch.nn(Neural network, 神经网络)内包含Pytorch神经网络框架 Containers: 容器 Convolution Layers: 卷积层 Pooling Layers: 池化层 Padding Layers: 填充层 Non-linear Activations (weighted sum, nonlinearity):非线性激活 Non-linear Activations (other):非线性激活 Normalization Layers:归一化层 Recurrent Layers:递归层 Tr

    2024年02月14日
    浏览(30)
  • Required request parameter ‘name‘ for method parameter type String is not present 报错解决方法

    注解 支持的类型 支持的请求类型 支持的  Content-Type 请求示例 @PathVariable url GET 所有 /test/{id} @RequestParam url GET 所有 /test?id=1 @RequestBody Body POST/PUT/DELETE/PATCH json {    \\\"id\\\" : 1 }      

    2024年02月11日
    浏览(36)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包