目录
1. 基于Module构建自己的网络
2. Module的初始化变量
3. Modules中需要子类 forward()
4. Modules中其他内置函数
1. 基于Module构建自己的网络
torch.nn.Module是所有神经网络模块的基类,如何定义自已的网络:
- 由于 Module 是神经网络模块的基类,自己的模型应该要继承这个类
- 要实现 torch.nn.Module 中的forward函数,从而进行网络的前向传播
- 一般把网络中具有可学习参数的层放在构造函数__init__()中
- 把不具有可学习参数的层(如ReLU)放在forward中,并通过nn.functional来代替
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module): # 继承nn.Module类
def __init__(self):
super(Model, self).__init__()
# 把具有可学习参数的层放在构造函数中
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x): # 实现forward函数
# 在forward中直接使用torch.nn.functional.relu()函数
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
2. Module的初始化变量
nn.Module 类内置了一些初始化变量。包括在模块 forward、 backward 和权重加载等时候会被调用的的 hooks,也定义了 parameters 和 buffers,如源码所示:
def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self.training = True # 当前训练/测试的状态
self._parameters = OrderedDict() # 在训练过程中会随 BP 而更新的参数
self._buffers = OrderedDict() # 在训练过程中不会随 BP 而更新的参数
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict() # Backward 完成后会被调用的 hook
self._is_full_backward_hook = None
self._forward_hooks = OrderedDict() # Forward 完成后会被调用的 hook
self._forward_pre_hooks = OrderedDict() # Forward 前会被调用的 hook
self._state_dict_hooks = OrderedDict() # 得到 state_dict 以后会被调用的 hook
self._load_state_dict_pre_hooks = OrderedDict() # 加载state_dict 前会被调用的 hook
self._modules = OrderedDict() # 网络的子模块
3. Modules中需要子类 forward()
注意:在网络训练过程中,直接通过mode(input) 自动调用forward函数,而非model.forward(input)进行调用,因为mode(input)会自动调用self.__call__,接下来这些 hooks 在模块被调用时候的执行顺序如下图所示:
主要顺序如下:
- 执行_forward_pre_hooks
- 再调用 forward
- 执行_forward_hooks
- 执行_backward_hooks
4. Modules中其他内置函数
除了初始化的成员变量之外,Modules还内置了很多函数,具体包含以下几类:
(1) 属性访问:modules(), named_modules(), buffers(), named_buffers(), children(), named_children(), parameters(), named_parameters() 等
(2) 属性设置:register_parameter(),register_buffer(),register_forward_hook(),register_forward_pre_hook() 等
(3) 参数转换/转移:cpu(), cuda(), float(), double() 等
(4) 状态转换:train(), eval() 等
对于这些内置函数的详细介绍,在 PyTorch系列相关文章-Aaron_neil的csdn博客 持续更新中!
本文所参考的部分博客:
[1] pytorch 入坑三:nn module - 知乎
[2] torch.nn.Module模块简单介绍_allan2222的博客-CSDN博客文章来源:https://www.toymoban.com/news/detail-415757.html
[3] PyTorch 源码解读之 nn.Module详解_OpenMMLab的博客-CSDN博客文章来源地址https://www.toymoban.com/news/detail-415757.html
到了这里,关于Pytorch 容器 - 1. Module类介绍的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!