torch.autograd.Function的使用

这篇具有很好参考价值的文章主要介绍了torch.autograd.Function的使用。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

(个人理解仅供参考)

1 什么情况下使用

自己定义的网络结构,没有现成的,就得手写forward和backward

2 怎么使用

2.1 forward

前向传播的表达式

2.2 backward

求导结果

2.3 举例

前向传播表达式:y = w * x + b
假设f()是我们关于y的loss函数,那么z = f(y)即为loss值
现在要求loss对w、x、b的偏导(假设只有一层):
dz/dx = dz/dy * dy/dx = dz/dy * w
dz/dw = dz/dy * dy/dw = dz/dy * x
dz/db = dz/dy * dy/db = dz/dy * 1
好在dz/dy不用我们再求了,它就是 backward 的参数grad_output。那么grad_output是从哪来的呢?其实就是 forward 会 return output 给 backward ,至于 backward 怎么把 output 变为 grad_output 就不用细究了。
所以:
dz/dx = grad_output * w
dz/dw = grad_output * x
dz/db = grad_output * 1

因此,对于y = w * x + b,我们的代码为:

import torch
from torch.autograd import Function

class MultiplyAdd(Function):

    @staticmethod
    def forward(ctx, w, x, b):
        ctx.save_for_backward(w, x)	 # 保存参数
        output = w * x + b
        return output	# 传给backward

    @staticmethod
    def backward(ctx, grad_output):
        w, x = ctx.saved_tensors
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        return grad_w, grad_x, grad_b	# 传给forward


Linear = MultiplyAdd.apply

2.4 模板

"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
    def forward(self, inputs, parameters):
        self.saved_for_backward = [inputs, parameters]
        # output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
        return output
 
    def backward(self, grad_output):
        inputs, parameters = self.saved_tensors # 或self.saved_variables
        # grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
        return grad_input
"""

2.5 验证

验证的话需要使用torch.autograd.gradcheck,给上我的完整代码,验证部分在最后:

import torch
from torch.autograd import Function, gradcheck

"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
    def forward(self, inputs, parameters):
        self.saved_for_backward = [inputs, parameters]
        # output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
        return output
 
    def backward(self, grad_output):
        inputs, parameters = self.saved_tensors # 或者是self.saved_variables
        # grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
        return grad_input
"""


class MultiplyAdd(Function):

    @staticmethod
    def forward(ctx, w, x, b):
        ctx.save_for_backward(w, x)
        output = w * x + b
        return output

    @staticmethod
    def backward(ctx, grad_output):
        w, x = ctx.saved_tensors
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        return grad_w, grad_x, grad_b


Linear = MultiplyAdd.apply

x = torch.ones(1, requires_grad=True, dtype=torch.float64)
w = torch.rand(1, requires_grad=True, dtype=torch.float64)
b = torch.rand(1, requires_grad=True, dtype=torch.float64)

# print("start forward...")
# z = MultiplyAdd.apply(w, x, b)
# print("start backward...")
# z.backward()
#
# print(x.grad, w.grad, b.grad)

test = gradcheck(Linear, (x, w, b), eps=1e-6)
print(test)

3 存疑

现在我只是会用了这个,但是如果是两层的全连接层,这段代码是怎么工作的?这个问题我还没想明白,留个坑文章来源地址https://www.toymoban.com/news/detail-410605.html

到了这里,关于torch.autograd.Function的使用的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Office 2019 激活-探索(仅供参考)

    1. 新建txt文件 2. 在新建txt文件中输入 @echooff (cd /d \\\"%~dp0\\\")(NET FILE||(powershell start-process -FilePath \\\'%0\\\' -verb runas)(exit /B)) NUL 21 title Office 2019 Activator r/Piracy echo Converting... mode 40,25 (if exist \\\"%ProgramFiles%Microsoft OfficeOffice16ospp.vbs\\\" cd /d \\\"%ProgramFiles%Microsoft OfficeOffice16\\\")(if exist \\\"%ProgramFiles(x

    2023年04月08日
    浏览(59)
  • LCD1602操作指令(仅供参考)

    1.清屏指令( 0000 0001 ) 1.清除液晶显示器,即将DDRAM的内容全部清除。 2.光标回到液晶屏左上方。 3.地址计数器(AC)的值设置为0。 2.光标归位指令(0000 001x) 1.把光标返回到液晶屏左上方。 2.把地址计数器(AC)的值设置为0。 3.保持DDRAM的内容不变。 3.模式设置指令(0000

    2024年02月05日
    浏览(34)
  • restful接口设计规范[仅供参考]

    应该尽量将API部署在专用域名之下。 如果确定API很简单,不会有进一步扩展,可以考虑放在主域名下。 应该将API的版本号放入URL。 另一种做法是,将版本号放在HTTP头信息中,但不如放入URL方便和直观。Github就采用了这种做法。 因为不同的版本,可以理解成同一种资源的不

    2024年02月15日
    浏览(45)
  • 我的医学预测模型评价步骤(仅供参考)

    个人意见,仅供参考 一切变化都是源于决策曲线分析,据说决策曲线分析已经获得了预测模型界的认可,也已经被写进了预测模型的报告指南–TRIPOD 中。一篇在pubmed上发表的关于如何使用决策曲线分析的指导论文,给出了使用决策曲线分析的几点推荐:1. 确定临床使用场景

    2024年02月02日
    浏览(60)
  • halcon不能连接海康相机解决方法(仅供参考)

    halcon不能连接相机有很多原因,这里作者给出其中一种的解决方法。 首先需要自行先下载海康软件,   1.首先点开Development, 2.根据图片的路径,点开以下文件夹  3.根据自己电脑安装的halcon版本打开对应文件夹     4.我的电脑是win64位的,根据自己的电脑打开对应的文件夹。

    2024年02月12日
    浏览(40)
  • uniapp获取手机号(前端部分,仅供参考~)

    html部分 js部分 api部分

    2024年02月09日
    浏览(53)
  • 有关 Rust 交叉编译的一些思路 (仅供参考)

    近来, 使用 Rust 语言开发的应用程序, 渐渐融入了开发者以及普通用户的日常生活. 它们不仅出现在我们常用的工作平台上, 不少嵌入式设备或者云服务器上也多见它们的身影. Rust 是一种需要编译的语言, 且一些 crate 仍需要 C/C++ 的构建环境. 大多数时候, 在 Rust 工具链 (toolchai

    2024年02月09日
    浏览(42)
  • 分布式计算----期末复习题(仅供参考)

    一.单选题,每个2分 1.Hadoop 之父 是下面的哪一位?(B) A. James Gosling        B.Doug Cutting    C.Matei Zaharia   D.Linus Benedict Torvalds 2.Hadoop中,用于 处理或者分析海量数据 的组件是哪一个?(  B   ) A.HDFS     B.MapReduce     C.Yarn   D.以上选项都不是 3.HDFS中 存储和管理元数据

    2024年02月10日
    浏览(49)
  • 【Software Testing】【期末习题库】【2023年春】【仅供参考】

    类型 总分占比 平时成绩 40% 考试/考查 60% 题型 题量×分值 备注 单选 20×1’ 多选 10×3’ 全对=3’,错1个=0’,少选=-1’ 填空 10×2’ 判断 5×2’ 大题 2×10’ 平时习题(3次): ①软件测试概述 ②黑盒测试 ③白盒测试和性能测试 期中考试(1次) 大题1:平时课上练习过的习题

    2024年02月10日
    浏览(45)
  • 删除文件后磁盘空间未释放,只能重启进程?(仅供参考)

    很多运维同学都遇到过“磁盘告警”,遇到这种情况就需要去清理磁盘。 这时候,很多同学通过各种途径、手段、命令找到了占用磁盘比较大的文件,然后大手一挥,  以为这样任务就完成了,谁知道,一查询磁盘使用量还是居高不下,完全没有释放。 这是因为在Linux中,如

    2024年02月11日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包