一文带你搞懂PyTorch中所有模型查看的函数model.modules()系列

这篇具有很好参考价值的文章主要介绍了一文带你搞懂PyTorch中所有模型查看的函数model.modules()系列。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

model一般继承nn.Model 他的实例一般具有几个有序字典

_modules,_parameters,_buffers,表示当前model的子模块,自己注册的parameters和buffers

注意,_modules字典keys对应子模块名字,value对应子模块的实例,所以可以迭代的调用子模块的子模块,比如下面两个函数

model._modules["blocks"]._modules["0"]._modules["attn"]._modules["qkv"]._parameters.keys()#odict_keys(['weight', 'bias'])

model._modules["blocks"]._modules["0"]._modules["attn"]._modules["qkv"]._buffers.keys()#odict_keys(['weight_mask'])

因为是字典,所以可以用 keys() value() items()方法

比如model._modules.items()就是一个包含模型所有子模块的迭代器

 

接下来看几个model的方法

对于生成器,我们需要用循环或者next()来获取数据,或者list/dict()转化为ist/dict

什么是生成器,迭代器,可迭代对象,见

一文看懂python的迭代器和可迭代对象 - 知乎 (zhihu.com)

Python迭代器和生成器详解 - 知乎 (zhihu.com)

 

model._buffers#OrderedDict()

model.buffers()#<generator object Module.buffers at 0x7f7a80496d60>

list(model.buffers())[0].size()#torch.Size([2304, 768])

type(list(model.named_buffers())[0])#tuple

list(model.named_buffers())[0][0]#'blocks.0.attn.head_mask'

dict(model.named_buffers()).keys()

dict(model.buffers())#ValueError: dictionary update sequence element #0 has length 2304; 2 is required

len(list(model.buffers()))#12

# modules() 强制遍历

model.named_modules()/ model.modules()

model.modules()迭代遍历模型的所有子层,包括子层的子层

    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
        r"""Returns an iterator over all modules in the network, yielding
        both the name of the module as well as the module itself.

        Args:
            memo: a memo to store the set of modules already added to the result
            prefix: a prefix that will be added to the name of the module
            remove_duplicate: whether to remove the duplicated module instances in the result
                or not

        Yields:
            (str, Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
            ...     print(idx, '->', m)

            0 -> ('', Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            ))
            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

        """

        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
                    yield m

 前者多返回一个参数名称,这样有利于访问和初始化或修改参数

for name, layer in model.named_modules():
    if 'conv' in name:
        对layer进行处理

#当然,在没有返回名字的情形中,采用isinstance()函数也可以完成上述操作
for layer in model.modules():
    if isinstance(layer, nn.Conv2d):
        对layer进行处理

# children()只取子层

model.named_children()/model.children()

 model.children()只会遍历模型的子层,不会子层的子层遍历

    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
        r"""Returns an iterator over immediate children modules, yielding both
        the name of the module as well as the module itself.

        Yields:
            (str, Module): Tuple containing a name and child module

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, module in model.named_children():
            >>>     if name in ['conv4', 'conv5']:
            >>>         print(module)

        """
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

#  parameters()   只提供可优化的参数,recurse = True 默认迭代

 model.named_parameters()/model.parameters()

 迭代地返回模型的所有参数,包括自己注册的

 # buffers()   只提供不可优化的参数,recurse = True 默认迭代

 model.named_buffers()/ model.buffers()

model._buffers#OrderedDict()

model.buffers()#<generator object Module.buffers at 0x7f7a80496d60>

list(model.buffers())[0].size()#torch.Size([2304, 768])

type(list(model.named_buffers())[0])#tuple

list(model.named_buffers())[0][0]#'blocks.0.attn.head_mask'

dict(model.named_buffers()).keys()

dict(model.buffers())#ValueError: dictionary update sequence element #0 has length 2304; 2 is required

len(list(model.buffers()))#12

#model._parameters.keys()#odict_keys(['cls_token', 'pos_embed'])

    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        for _, buf in self.named_buffers(recurse=recurse):
            yield buf

    def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
        r"""Returns an iterator over module buffers, yielding both the
        name of the buffer as well as the buffer itself

        """
        gen = self._named_members(
            lambda module: module._buffers.items(),
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
        yield from gen

            >>> # recurse = True 默认迭代

            >>> for name, buf in self.named_buffers():

            >>>     if name in ['running_var']:

            >>>         print(buf.size())

# state_dict字典  返回包括bufferss

model.state_dict()

model.state_dict()返回的是一个字典

包括所有参数

一个有序字典,该字典的键即为模型定义中有可学习参数的层的名称+weight或+bias,值则对应相应的权重或偏差,无参数的层则不在其中

包括para和buffers???

model.state_dict()直接返回模型的字典,和前面几个方法不同的是这里不需要迭代,它本身就是一个字典,可以直接通过修改state_dict来修改模型各层的参数,用于参数剪枝特别方便。详细的state_dict方法(24条消息) PyTorch模型保存深入理解_Ciao112的博客-CSDN博客文章来源地址https://www.toymoban.com/news/detail-738238.html

到了这里,关于一文带你搞懂PyTorch中所有模型查看的函数model.modules()系列的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【微信小程序】一文带你搞懂小程序的页面配置和网络数据请求

    每个小程序页面都有一个 .json 文件,该文件用来对小程序的页面进行配置。 小程序中,每个页面都有自己的.json配置文件,用来对当前页面的窗口外观、页面效果等进行配置。 小程序中,app.json中的 window 节点,可以全局配置小程序中 每个页面 的窗口表现。 如:当在app.js

    2024年02月02日
    浏览(32)
  • 【MDX】一文带你搞懂SQL Server Analysis Services 的安装和使用

    目录 Step 1: Install developer and management tools 安装 new stand-alone SQL Server installation or add the feature to an existing installation 安装 SQL Server Management Studio 安装 SSDT 安装 Visual Studio Step 2: Install databases Step 3: Install projects Step 4: 创建项目 Step 5: 定义数据源 Step 6: 部署Analysis Services项目 Step 7: F

    2023年04月08日
    浏览(38)
  • 手把手带你搞懂AMS启动原理

    彻底搞懂AMS即ActivityManagerService,看这一篇就够了 最近那么多教学视频(特别是搞车载的)都在讲AMS,可能这也跟要快速启动一个app(甚至是提高安卓系统启动速度有关),毕竟作为安卓系统的核心系统服务之一,AMS以及PMS都是很重要的,而我之前在 应用的开端–PackageManag

    2024年02月12日
    浏览(35)
  • Linux 有哪些搜索方式?5分钟带你搞懂!

    5分钟带你掌握 Linux 的三种搜索方式 1.find 命令 find 命令是用来在给定的目录下查找符合给定条件的文件 语法格式: find [查找起始路径] [查找条件] [处理动作] (1)根据名称查找: find [查找起始路径] -name 文件名 或者 find [查找起始路径] -iname 文件名 -name \\\"PATERN\\\":完全匹配文

    2024年01月16日
    浏览(30)
  • 带你搞懂人工智能、机器学习和深度学习!

    不少高校的小伙伴找我聊入门人工智能该怎么起步,如何快速入门,多长时间能成长为中高级工程师(聊下来感觉大多数学生党就是焦虑,毕业即失业,尤其现在就业环境这么差),但聊到最后,很多小朋友连人工智能和机器学习、深度学习的关系都搞不清楚。 今天更文给大

    2024年02月02日
    浏览(38)
  • 一篇文章带你搞懂前端Cookie

    浏览器Cookie相信各位点进这篇文章的小伙伴应该不陌生了,它是前端领域中一个非常重要的内容,当然也是面试的一个考点,不知道各位小伙伴是否真正掌握了Cookie呢?当然没有掌握也是没有关系的,可以跟着小编的脚步一起来学习一下前端Cookie,没有熟练掌握的小伙伴看完这

    2024年02月04日
    浏览(33)
  • 一文让你搞懂javascript如何实现继承

    一、本文想给你聊的东西包含一下几个方面:(仅限于es6之前的语法哈,因为es6里面class这用上了。。) 1.原型是啥?原型链是啥? 2.继承的通用概念。 3.Javascript实现继承的方式有哪些?   二、原型是啥?原型链是啥? 1.原型是函数本身的prototype属性。 首先js和java不

    2024年02月04日
    浏览(36)
  • 一篇文章带你搞懂stm32工程文件

    本文以stm32f4为例,讲解stm32标准库工程中各个文件的作用,学艺不精,如有错误,望大家私信或评论指出。 先看思维导图 startup_stm32f427xx.s  该文件是stm32的启动文件,由汇编语言编写,主要是做stm32上电时的配置设置(如堆栈指针,时钟数)并跳转到main函数中,执行c代码。

    2024年02月21日
    浏览(33)
  • 一篇文章带你搞懂GIT、Github、Gitee

    本文介绍了GIt,GitHub,Gitee的使用,与IDEA的Git配置,跟着文章来做你很快就能学会操作Git,利用其进行版本控制与代码托管,学习Git的使用、Git常用命令、Git分支,分支是团队协作的基础,介绍了团队内,外协作和Github远程仓库的操作、使用IDEA中的Git、IDEA中GIt的使用、在I

    2023年04月19日
    浏览(38)
  • 一篇文章带你搞懂微信小程序的开发过程

    小程序想必大家应该都不陌生了吧,今天小编带大家一起来学习下微信小程序的开发过程吧。 这个不一一介绍,网上有教程,申请成功后打开后台,我们找到小程序,下载微信开发者工具,如图: 这里我们选择普通小程序开发工具,点击微信开发者工具,如图: 然后选择相

    2024年02月09日
    浏览(23)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包