Pytorch中的forward的理解

这篇具有很好参考价值的文章主要介绍了Pytorch中的forward的理解。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

1. 关于forward的两个小问题

1.1 为什么都用def forward,而不改个名字?

在Pytorch建立神经元网络模型的时候,经常用到forward方法,表示在建立模型后,进行神经元网络的前向传播。说的直白点,forward就是专门用来计算给定输入,得到神经元网络输出的方法。

在代码实现中,也是用def forward来写forward前向传播的方法,我原来以为这是一种约定熟成的名字,也可以换成任意一个自己喜欢的名字。

但是看的多了之后发现并非如此:Pytorch对于forward方法赋予了一些特殊“功能”

(这里不禁再吐槽,一些看起来挺厉害的Pytorch“大神”,居然不知道这个。。。只能草草解释一下:“就是这样的。。。”)

1.2 forward有什么特殊功能?
第一条:.forward()可以不写

我最开始发现forward()的与众不同之处就是在此,首先举个例子:

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input

    def forward(self,x):
        return self.input * x

T = test(8)
print(T(6))

# print(T.forward(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48

Process finished with exit code 0

可以发现,T(6)是可以输出的!而且不用指定,默认了调用forward方法。当然如果非要写上.forward()这也是可以正常运行的,和不写是一样的。

如果不调用Pytorch(正常的Python语法规则),这样肯定会报错的

# import torch.nn as nn  #不再调用torch
class test():
    def __init__(self, input):
        self.input = input

    def forward(self,x):
        return self.input * x

T = test(8)
print(T.forward(6))
print("************************")
print(T(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
************************
Traceback (most recent call last):
  File "C:\Users\Lenovo\Desktop\DL\pythonProject\tt.py", line 77, in <module>
    print(T(6))
TypeError: 'test' object is not callable

Process finished with exit code 1

这里会报:‘test’ object is not callable
因为class不能被直接调用,不知道你想调用哪个方法。

第二条:优先运行forward方法

如果在class中再增加一个方法:

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input

    def byten(self):
        return self.input * 10

    def forward(self,x):
        return self.input * x

T = test(8)
print(T(6))
print(T.byten())
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
80

Process finished with exit code 0

可以见到,在class中有多个method的时候,如果不指定method,forward是会被优先执行的。

2. 总结

在Pytorch中,forward方法是一个特殊的方法,被专门用来进行前向传播。

20230605 更新

应评论要求,增加forward的官方定义,这块我就不搬运PyTorch官网的内容了,直接传送门走你:nn.Module.forward。

20230919 大更新

首先非常感谢大家喜欢本文!这篇文章本来是我自己的“随手记”没想到有这么多C友浏览过!

其实在写完本文后我是有些遗憾的,因为本文仅是用了实验的方法探索出了.forward()的表象,而它的运作机理却没有说明白,知其然不知其所以然!

在此感谢下面 Mr·小鱼 的评论给了我启迪,因为魔术方法__call__()的特性确实很符合.forward()的表象,但是我对着nn.Module的源码一脸茫然,因为源码中压根没有__call__()方法的定义!!

于是我抱着试试的心态,在PyTorch官网上查了下PyTorch的历史版本,这一查确实查到了线索:
Pytorch中的forward的理解
下面是从PyTorch的上古版本v0.1.12中截取forward()__call__()方法的源码:

class Module(object):
#...中间不相关代码省略...
    def forward(self, *input):
        """Defines the computation performed at every call.

        Should be overriden by all subclasses.
        """
        raise NotImplementedError
#...中间不相关代码省略...
    def __call__(self, *input, **kwargs):
        result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}'"
                    "didn't return None".format(hook))
        var = result
        while not isinstance(var, Variable):
            var = var[0]
        creator = var.creator
        if creator is not None and len(self._backward_hooks) > 0:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                creator.register_hook(wrapper)
        return result

我们可以看到在__call__()方法中直接把方法self.forward()作为函数的返回值,由于魔术方法__call__()可以被自动调用,这也就解释了为什么forward()可以自动运行。

至于该方法中的其他内容,都是与hook钩子函数的操作相关,这部分暂不做探索。。。

那我们回到现在的版本(我现在使用的是1.8.1):
Pytorch中的forward的理解
通过源码可以看到经历了多个版本的更迭,forward()__call__()居然改名字了!!

    forward: Callable[..., Any] = _forward_unimplemented
    ...
    __call__ : Callable[..., Any] = _call_impl

这也就是为什么我之前在源码中没找到这两个方法定义的原因。。。准确来说这里也不能说是改名字了,而是多了一个名字,至于PyTorch为什么会有这样的更改,我确实也没想到原因。。。

其中_forward_unimplemented()倒是没变:

def _forward_unimplemented(self, *input: Any) -> None:
    r"""Defines the computation performed at every call.

    Should be overridden by all subclasses.

    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError

_call_impl()相比于上古版本,已经复杂到了令人发指的地步!

    def _call_impl(self, *input, **kwargs):
        # Do not call functions when jit is used
        full_backward_hooks, non_full_backward_hooks = [], []
        if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0:
            full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()

        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result

        bw_hook = None
        if len(full_backward_hooks) > 0:
            bw_hook = hooks.BackwardHook(self, full_backward_hooks)
            input = bw_hook.setup_input_hook(input)

        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result

        if bw_hook:
            result = bw_hook.setup_output_hook(result)

        # Handle the non-full backward hooks
        if len(non_full_backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in non_full_backward_hooks:
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
                self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

        return result

其变复杂的原因是各种钩子函数_hook的调用,有兴趣的童鞋可以参考这篇文章:pytorch 中_call_impl()函数。这部分绝对是超纲了!

最后我想再做几个实验加深理解:
实验①

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input

    def forward(self,x):
        return self.input * x

T = test(8)
print(T.__call__(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py 
48

Process finished with exit code 0

这里T.__call__(6)写法等价于T(6)

实验②

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input

    def forward(self,x):
        return self.input * x

T = test(8)
print(T.forward(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py 
48

Process finished with exit code 0

这里T.forward(6)的写法虽然也能正确地计算出结果,但是不推荐这么写,因为这会导致__call__()调用一遍forward(),然后手动又调用了一遍forward(),造成forward()的重复计算,浪费计算资源。

实验③

import torch.nn as nn
class test(nn.Module):
    def __init__(self, input):
        super(test,self).__init__()
        self.input = input

    # def forward(self,x):
    #     return self.input * x

T = test(8)
print(T())
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py 
Traceback (most recent call last):
  File "C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py", line 11, in <module>
    print(T())
  File "D:\Users\Lenovo\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\Users\Lenovo\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 201, in _forward_unimplemented
    raise NotImplementedError
NotImplementedError

forward()是必须要写的,因为__call__()要自动调用forward()。如果压根不写forward()__call__()将无方法可以调用。按照forward()的源码,这里会raise NotImplementedError

至此,我觉得PyTorch中的forward应该算是全说明白了。。。文章来源地址https://www.toymoban.com/news/detail-401864.html

到了这里,关于Pytorch中的forward的理解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • cuda+anaconda+pytorch按照教程

    1、查看当前显卡支持的最高版本,有两种方式: 1)NVIDIA控制面板—帮助—系统信息—组件—NVCUDA.dll对应版本 请注意,12.2为本机CUDA支持的最高版本 nvidia-smi显示的同上,也表示cuda支持的最高版本 安装cuda后在cmd窗口使用nvcc -V命令看到的是自己安装的cuda版本,这个安装后不

    2024年02月12日
    浏览(22)
  • 从零开始理解Linux中断架构(1)-前言

    前言         前段时间在转行手撸WIFI路由器,搞wifi路由器需要理解网络驱动程序,以太网卡驱动程序,无线WIFI驱动程序,而网卡驱动的关键路径就在中断程序中,需要了解NIC设备驱动程序如何收发数据,为了彻底的知道数据包是如何二层传递上来的,又需要了解一点Lin

    2024年02月09日
    浏览(44)
  • 深入理解PyTorch中的train()、eval()和no_grad()

    ❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️ 👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈 (封面图由文心一格生成) 在PyTorch中,train()、eval()和no_grad()是三个非常重

    2023年04月08日
    浏览(36)
  • OpenCV读取图像时按照BGR的顺序HWC排列,PyTorch按照RGB的顺序CHW排列

    在OpenCV中,读取的图片默认是HWC格式,即按照高度、宽度和通道数的顺序排列图像尺寸的格式。我们看最后一个维度是C,因此最小颗粒度是C。 例如,一张形状为256×256×3的RGB图像,在OpenCV中读取后的格式为[256, 256, 3],其中最后一个维度表示图像的通道数。在OpenCV中,可以通

    2024年02月04日
    浏览(29)
  • forward函数——浅学深度学习框架中的forward

    (本应该出一篇贯穿神经网络的文章的,但是由于时间关系,就先浅浅记录一下,加深自己的理解吧吧)。 forward 函数是深度学习框架中常见的一个函数,用于定义神经网络的前向传播过程。 在训练过程中,输入数据会被传入神经网络的 forward 函数,然后经过一系列的计算和

    2023年04月27日
    浏览(29)
  • 【Pytorch】提取模型中间层输出(hook, .register_forward_hook(hook=hook))

    需要转换的对象 模型 损失函数 数据(特征数据、标签) 4.3.1 单进程多GPU训练(DP)模式 torch.nn.DataParallel 并行的多卡都是由一个进程进行控制,在进行梯度的传播时,是在主GPU上进行的。 将模型布置到多个指定GPU上 model = torch.nn.DataParallel(model,device_ids=device_list) 指定模型布置的

    2024年02月13日
    浏览(25)
  • PyTorch的CUDA错误:Error 804: forward compatibility was attempted on non supported HW

    宿主机为Ubuntu20.04 + gtx1060,Nvidia driver版本为510.85.02。 安装环境为:tensorrt8.4 安装完成后,一当调用cuda环境就会报错:Error 804: forward compatibility was attempted on non supported HW。 检查问题原因 在Linux宿主机上使用docker(版本= 19.3)之前,请确保安装了nvidia-container-runtime和nvidia-cont

    2023年04月08日
    浏览(31)
  • 31 对集合中的字符串,按照长度降序排列

            思路:使用集合的sort方法,新建一个Comparator接口,泛型是String,重写里面的compare方法。         运行结果:          扩充:点击Comparator,查看接口内部:发现加了@FunctionalInterface,说明可以使用箭头函数,直接使用箭头函数就能表示Comparator接口以及它的compara

    2024年02月14日
    浏览(34)
  • 神经网络中的前向传播(Forward Propagation)和后向传播(Backward Propagation)

    有时候会搞混这两个概念。什么是前向传播?不是只有后向传播吗?后向传播好像是用来更新模型参数的,前向传播是什么东西? 带着疑问再次梳理一遍: 前向传播是神经网络进行预测的过程。在这个过程中,输入数据沿着神经网络从输入层经过隐藏层(如果有的话)最终

    2024年02月20日
    浏览(31)
  • nginx负载转发源请求http/https:X-Forwarded-Proto及nginx中的转发报头

    今天在排查服务器的问题时最后定位到服务器因为经过了运维这一层的处理,转发过来的请求不管用户请求的是https还是http,我们的proxy服务器收到的都是80端口上的http。于是联系相关部门了解有没有现成的可用的这样一个字段来获得这个值。公司用的也是标准报头,即X-Fo

    2024年02月16日
    浏览(47)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包