pytorch获得模型的参数量和模型的大小

这篇具有很好参考价值的文章主要介绍了pytorch获得模型的参数量和模型的大小。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

参考

  • Finding model size
  • Pytorch模型中的parameter与buffer
  • What pytorch means by buffers?
  • Pytorch中Module,Parameter和Buffer的区别
  • torch.Tensor.element_size
  • torch.Tensor.nelement

buffer和parameter

在模型中,有buffer和parameter两种,其中parameter就是我们一般认为的模型的参数,它有梯度,可被训练更新。但是buffer没有梯度,不能被训练更新。
我们可以通过torch.nn.Module.buffers()torch.nn.Module.named_buffers()返回模型中的buffer。第二个函数同时返回自己定义的名称和buffer。
同时,我们可以通过torch.nn.Module.parameters()torch.nn.Module.named_parameters()返回模型中的parameter。第二个函数同时返回自己定义的名称和parameter。
两个函数都有一个bool型参数recurse,默认为true。如果为true,将递归的查找所有子层的参数。否则只查找第一层的子层。

torch.Tensor.nelement和torch.Tensor.element_size

我们得到的parameter和buffer都是Tensor类型的参数,而对于Tensor,第一个函数可以返回这个Tensor中的元素个数,比如矩阵中有多少数。第二个函数可以返回这个Tensor所对应的数据类型的字节大小。比如float32就是4字节。

获得模型的大小

def getModelSize(model):
    param_size = 0
    param_sum = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
        param_sum += param.nelement()
    buffer_size = 0
    buffer_sum = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        buffer_sum += buffer.nelement()
    all_size = (param_size + buffer_size) / 1024 / 1024
    print('模型总大小为:{:.3f}MB'.format(all_size))
    return (param_size, param_sum, buffer_size, buffer_sum, all_size)

函数也很好理解,通过model.parameters()返回能迭代所有参数的迭代器,之后就能通过for循环得到所有的parameter。buffer也是类似。
返回的param_size是所有parameters的参数字节MB大小,buffer_size是所有buffer的参数字节MB大小,all_size就是模型的MB大小。文章来源地址https://www.toymoban.com/news/detail-409398.html

到了这里,关于pytorch获得模型的参数量和模型的大小的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 解决PyTorch DDP: Finding the cause of “Expected to mark a variable ready only once“

    早上做消融实验的时候需要复现俩月前的实验结果,但是莫名其妙同样的代码和环境却跑不通了,会在loss.backward()的时候报如下错误: RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the ``forward`` function. Please

    2024年02月07日
    浏览(54)
  • 关于mmdetection、mmrotate如何计算参数量、计算量和速度FPS

    近几天跑完实验后,发现效果还是不错,于是开始进行模型的参数量、计算量和速度指标的计算对比,话不多说,直接上干货。 -------------------------------------------------------------------------------------------------------------------------- 首先记住一句话: 模型的参数量越小,这个模型的计

    2024年01月15日
    浏览(37)
  • 在pytorch中保存模型或模型参数

    在 PyTorch 中,我们可以使用 torch.save 函数将 PyTorch 模型保存到文件。这个函数接受两个参数:要保存的对象(通常是模型),以及文件路径。 在上面的例子中, model.state_dict() 用于获取模型的状态字典(包含模型的所有参数)。然后, torch.save 函数将这个状态字典保存到指定

    2024年02月05日
    浏览(36)
  • pytorch打印模型结构和参数

    当我们使用pytorch进行模型训练或测试时,有时候希望能知道模型每一层分别是什么,具有怎样的参数。此时我们可以将模型打印出来,输出每一层的名字、类型、参数等。 常用的命令行打印模型结构的方法有两种: 一是直接print 二是使用torchsummary库的summary 但是二者在输出

    2024年02月08日
    浏览(42)
  • PyTorch 参数化深度解析:自定义、管理和优化模型参数

    目录 torch.nn子模块parametrize parametrize.register_parametrization 主要特性和用途 使用场景 参数和参数 注意事项 示例 parametrize.remove_parametrizations 功能和用途 参数 返回值 异常 使用示例 parametrize.cached 功能和用途 如何使用 示例 parametrize.is_parametrized 功能和用途 参数 返回值 示例

    2024年01月21日
    浏览(54)
  • 使用Optuna进行PyTorch模型的超参数调优

    Optuna是一个开源的超参数优化框架,Optuna与框架无关,可以在任何机器学习或深度学习框架中使用它。本文将以表格数据为例,使用Optuna对PyTorch模型进行超参数调优。 Optuna可以使用python pip安装,如pip install Optuna。也可以使用conda install -c conda-forge Optuna,安装基于Anaconda的py

    2024年02月08日
    浏览(48)
  • 【深度学习PyTorch入门】6.Optimizing Model Parameters 优化模型参数

    现在我们有了模型和数据,是时候通过优化数据上的参数来训练、验证和测试我们的模型了。训练模型是一个迭代过程;在每次迭代中,模型都会对输出进行猜测,计算其猜测中的误差( 损失 ),收集相对于其参数的导数的误差(如我们在上一节中看到的),并使用梯度下

    2024年01月24日
    浏览(61)
  • Linux中文件大小查看和数量统计

    在 Linux 中,有多种命令可以查看磁盘分区情况,其中常用的命令如下: 命令 说明 lsblk 该命令用于显示所有块设备,包括磁盘和它们的分区。执行该命令后,会列出所有磁盘的设备名、磁盘大小、分区情况等信息 df -h 该命令用于查看所有已经挂载的文件系统的使用情况。执

    2024年02月16日
    浏览(41)
  • 【pytorch】深度学习所需算力估算:flops及模型参数量

    确定神经网络推理需要的运算能力需要考虑以下几个因素: 网络结构:神经网络结构的复杂度直接影响运算能力的需求。一般来说,深度网络和卷积网络需要更多的计算能力。 输入数据大小和数据类型:输入数据的大小和数据类型直接影响到每层神经网络的计算量和存储需

    2024年02月04日
    浏览(42)
  • 【超详细小白必懂】Pytorch 直接加载ResNet50模型和参数实现迁移学习

    Torchvision.models包里面包含了常见的各种基础模型架构,主要包括以下几种:(我们以ResNet50模型作为此次演示的例子) AlexNet VGG ResNet SqueezeNet DenseNet Inception v3 GoogLeNet ShuffleNet v2 MobileNet v2 ResNeXt Wide ResNet MNASNet 首先加载ResNet50模型,如果如果需要加载模型本身的参数,需要使用

    2024年02月16日
    浏览(48)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包