(个人理解仅供参考)
1 什么情况下使用
自己定义的网络结构,没有现成的,就得手写forward和backward
2 怎么使用
2.1 forward
前向传播的表达式
2.2 backward
求导结果
2.3 举例
前向传播表达式:y = w * x + b
假设f()是我们关于y的loss函数,那么z = f(y)即为loss值
现在要求loss对w、x、b的偏导(假设只有一层):dz/dx
= dz/dy
* dy/dx
= dz/dy
* w
dz/dw
= dz/dy
* dy/dw
= dz/dy
* x
dz/db
= dz/dy
* dy/db
= dz/dy
* 1
好在dz/dy
不用我们再求了,它就是 backward 的参数grad_output
。那么grad_output
是从哪来的呢?其实就是 forward 会 return output 给 backward ,至于 backward 怎么把 output 变为 grad_output 就不用细究了。
所以:dz/dx
= grad_output
* w
dz/dw
= grad_output
* x
dz/db
= grad_output
* 1
因此,对于y = w * x + b,我们的代码为:
import torch
from torch.autograd import Function
class MultiplyAdd(Function):
@staticmethod
def forward(ctx, w, x, b):
ctx.save_for_backward(w, x) # 保存参数
output = w * x + b
return output # 传给backward
@staticmethod
def backward(ctx, grad_output):
w, x = ctx.saved_tensors
grad_w = grad_output * x
grad_x = grad_output * w
grad_b = grad_output * 1
return grad_w, grad_x, grad_b # 传给forward
Linear = MultiplyAdd.apply
2.4 模板
"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
def forward(self, inputs, parameters):
self.saved_for_backward = [inputs, parameters]
# output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
return output
def backward(self, grad_output):
inputs, parameters = self.saved_tensors # 或self.saved_variables
# grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
return grad_input
"""
2.5 验证
验证的话需要使用torch.autograd.gradcheck,给上我的完整代码,验证部分在最后:文章来源:https://www.toymoban.com/news/detail-410605.html
import torch
from torch.autograd import Function, gradcheck
"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
def forward(self, inputs, parameters):
self.saved_for_backward = [inputs, parameters]
# output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
return output
def backward(self, grad_output):
inputs, parameters = self.saved_tensors # 或者是self.saved_variables
# grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
return grad_input
"""
class MultiplyAdd(Function):
@staticmethod
def forward(ctx, w, x, b):
ctx.save_for_backward(w, x)
output = w * x + b
return output
@staticmethod
def backward(ctx, grad_output):
w, x = ctx.saved_tensors
grad_w = grad_output * x
grad_x = grad_output * w
grad_b = grad_output * 1
return grad_w, grad_x, grad_b
Linear = MultiplyAdd.apply
x = torch.ones(1, requires_grad=True, dtype=torch.float64)
w = torch.rand(1, requires_grad=True, dtype=torch.float64)
b = torch.rand(1, requires_grad=True, dtype=torch.float64)
# print("start forward...")
# z = MultiplyAdd.apply(w, x, b)
# print("start backward...")
# z.backward()
#
# print(x.grad, w.grad, b.grad)
test = gradcheck(Linear, (x, w, b), eps=1e-6)
print(test)
3 存疑
现在我只是会用了这个,但是如果是两层的全连接层,这段代码是怎么工作的?这个问题我还没想明白,留个坑文章来源地址https://www.toymoban.com/news/detail-410605.html
到了这里,关于torch.autograd.Function的使用的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!