【一起撸个DL框架】5 实现:自适应线性单元

这篇具有很好参考价值的文章主要介绍了【一起撸个DL框架】5 实现:自适应线性单元。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

  • CSDN个人主页:清风莫追
  • 欢迎关注本专栏:《一起撸个DL框架》
  • GitHub获取源码:https://github.com/flying-forever/OurDL
  • blibli视频合集:https://space.bilibili.com/3493285974772098/channel/series

5 实现:自适应线性单元🍇

1 简介

上一篇:【一起撸个DL框架】4 反向传播求梯度

上一节我们实现了计算图的反向传播,可以求结果节点关于任意节点的梯度。下面我们将使用梯度来更新参数,实现一个简单的自适应线性单元

我们本次拟合的目标函数是一个简单的线性函数: y = 2 x + 1 y=2x+1 y=2x+1,通过随机数生成一些训练数据,将许多组x和对应的结果y值输入模型,但是并不告诉模型具体函数中的系数参数“2”和偏置参数“1”,看看模型能否通过数据“学习”到参数的值。

【一起撸个DL框架】5 实现:自适应线性单元
图1:自适应线性单元的计算图

2 损失函数

2.1 梯度下降法

损失是对模型好坏的评价指标,表示模型输出结果与正确答案(也称为标签)之间的差距。所以损失值越小就说明模型越准确,训练过程的目的便是最小化损失函数的值

自适应线性单元是一个回归任务,我们这里将使用绝对值损失,将模型输出与正确答案之间的差的绝对值作为损失函数的值,即 l o s s = ∣ l − a d d ∣ loss=|l-add| loss=ladd

评价指标有了,可是如何才能达标呢?或者说如何才能降低损失函数的值?计算图中有四个变量: x , w , b , l x,w,b,l x,w,b,l,而我们训练过程的任务是调整参数 w , b w,b w,b的值,以降低损失。因此训练过程中的自变量是w和b,而把x和l看作常量。此时损失函数是关于w和b的二元函数 l o s s = f ( w , b ) loss=f(w,b) loss=f(w,b),我们只需要求函数的梯度 ▽ f ( w , b ) = ( ∂ f ∂ w , ∂ f ∂ b ) \triangledown f(w,b)=(\frac{\partial f}{\partial w},\frac{\partial f}{\partial b}) f(w,b)=(wf,bf),则梯度的反方向就是函数下降最快的方向。沿着梯度的方向更新参数w和b的值,就可以降低损失。这就是经典的优化算法:梯度下降法

2.2 补充

关于损失和优化的概念,大家可能还是有些模糊。上面损失只讲到了一个输入x值对应的模型输出与实际结果之间的差距,但使用整个数据集的平均差距可能更容易理解,就像中学的线性回归

图2所示,改变直线的斜率w,将改变直线与数据点的贴近程度,即改变了损失函数loss的值。

【一起撸个DL框架】5 实现:自适应线性单元
图2:损失与参数更新示意图

参考: 【深度学习】3-从模型到学习的思路整理_清风莫追的博客-CSDN博客

3 整理项目结构

我们的小项目的代码也渐渐多起来了,好的目录结构将使它更加易于扩展。关于python包结构的知识大家可以自行去了解,大致目录结构如下:

- example
- ourdl
	- core
		- __init__.py
		- node.py
	- ops
		- __init__.py
		- loss.py
		- ops.py
	__init__.py

给这个简单框架的名字叫做OurDL,使用框架搭建的计算图等程序放在example目录下。在ourdl/core/node.py中存放了节点基类和变量类的定义,在ourdl/ops/下存放了运算节点的定义,包括损失函数和加法、乘法节点等。

4 损失函数的实现

/ourdl/ops/loss.py中,

from ..core import Node

class ValueLoss(Node):
    '''损失函数:作差取绝对值'''
    def compute(self):
        self.value = self.parent1.value - self.parent2.value
        self.flag = self.value > 0
        if not self.flag:
            self.value = -self.value
    def get_parent_grad(self, parent):
        a = 1 if self.flag else -1
        b = 1 if parent == self.parent1 else -1
        return a * b

其中compute()方法很显然就是对两个输入作差取绝对值;get_parent_grad()方法求本节点关于父节点的梯度。有绝对值如何求梯度?大家可以画一画绝对值函数的图像。

5 修改节点类(Node)

ourdl/core/node.py

class Node:
    pass  # 省略了一些方法的定义,大家可以查看上一篇文章

    def clear(self):
        '''递归清除父节点的值和梯度信息'''
        self.grad = None
        if self.parent1 is not None:  # 清空非变量节点的值
            self.value = None
        for parent in [self.parent1, self.parent2]:
            if parent is not None:
                parent.clear()
    def update(self, lr=0.001):
        '''根据本节点的梯度,更新本节点的值'''
        self.value -= lr * self.grad  # 减号表示梯度的反方向

我在节点类中新增了两个方法,其中clear()用于清除多余的节点值和梯度信息,因为当节点值或梯度已经存在时会直接返回结果而不会递归去求了(get_grad()forward()的代码)。update()有一个学习率参数lr,更新幅度太大可能导致参数值一直在目标值左右晃悠,无法收敛

6 自适应线性单元

/example/01_esay/自适应线性单元.py

import sys
sys.path.append('../..')
from ourdl.core import Varrible
from ourdl.ops import Mul, Add
from ourdl.ops.loss import ValueLoss

if __name__ == '__main__':
    # 搭建计算图
    x = Varrible()
    w = Varrible()
    mul = Mul(parent1=x, parent2=w)
    b = Varrible()
    add = Add(parent1=mul, parent2=b)
    label = Varrible()
    loss = ValueLoss(parent1=label, parent2=add)
    # 参数初始化
    w.set_value(0)
    b.set_value(0)
    # 生成训练数据
    import random
    data_x = [random.uniform(-10, 10) for i in range(10)]  # 按均匀分布生成[-10, 10]范围内的随机实数
    data_label = [2 * data_x_one + 1 for data_x_one in data_x]
    # 开始训练
    for i in range(len(data_x)):
        x.set_value(data_x[i])
        label.set_value(data_label[i])
        loss.forward()  # 前向传播 --> 求梯度会用到损失函数的值
        w.get_grad()
        b.get_grad()
        w.update(lr=0.05)
        b.update(lr=0.1)
        loss.clear()
        print("w:{:.2f}, b:{:.2f}".format(w.value, b.value))
    print("最终结果:{:.2f}x+{:.2f}".format(w.value, b.value))
    

运行结果:

w:0.13, b:0.10
w:0.36, b:0.20
w:0.58, b:0.10
w:0.74, b:0.00
w:1.13, b:0.10
w:1.43, b:0.20
w:1.62, b:0.30
w:1.94, b:0.20
w:1.50, b:0.30
w:1.87, b:0.40
最终结果:1.87x+0.40

上面自适应线性单元的训练,已经能够大致展现深度学习模型的训练流程:

  • 搭建模型 --> 初始化参数 --> 准备数据 --> 使用数据更新参数的值

我们这里参数只更新了10次,结果就已经大致接近了我们的目标函数 y = 2 x + 1 y=2x+1 y=2x+1。大家可以试试更改学习率lr,训练数据集的大小,观察运行结果会发生怎样的变化。(必备技能:调参)


下一篇:【一起撸个深度学习框架】6 折与曲的相会——激活函数文章来源地址https://www.toymoban.com/news/detail-437944.html

到了这里,关于【一起撸个DL框架】5 实现:自适应线性单元的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【一起撸个深度学习框架】6 折与曲的相会——激活函数

    CSDN个人主页:清风莫追 欢迎关注本专栏:《一起撸个DL框架》 GitHub获取源码:https://github.com/flying-forever/OurDL blibli视频合集:https://space.bilibili.com/3493285974772098/channel/series 在上一节,我们实现了一个“自适应线性单元”,不断地将一个一次函数的输入和输出“喂”给它,它就

    2024年02月05日
    浏览(42)
  • 一起学数据结构(2)——线性表及线性表顺序实现

    目录 1. 什么是数据结构:  1.1 数据结构的研究内容: 1.2 数据结构的基本概念: 1.2.1 逻辑结构:  1.2.2 存储结构: 2. 线性表: 2.1 线性表的基本定义: 2.2 线性表的运用: 3 .线性表的顺序表示及实现(顺序表):    3.1 顺序表的概念及结构:  3.2 顺序表的代码实现: 3.2

    2024年02月14日
    浏览(46)
  • flask框架-认证权限(一):使用g对象存登录用户信息,认证权限一起实现

    apps         -user         __init__.py authen        __init__.py         token.py ext         __init__.py util.py        public.py         __init__.py app.py 依赖包 authen/token.py user/views.py 认证大致的逻辑: 1、用户登录时,生成token,前端保存token信息 2、前端发起请求时,将token携带在cook

    2024年02月09日
    浏览(44)
  • 用于永磁同步电机驱动器的自适应SDRE非线性无传感器速度控制(Matlab&Simulink实现)

    目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码Simulink仿真实现 本文方法基于状态依赖的里卡蒂方程(SDRE)控制技术及其梯度型神经网络的实时计算方法,允许在线控制PMSM。 为了实现用于永磁同步电机驱动器的自适应 SDRE(State-Dependent Riccati Equation)非线性无传感

    2024年02月15日
    浏览(45)
  • linux环境编程(1): 实现一个单元测试框架-2

    在之前的文章中, 介绍了如何实现一个类似gtest的单元测试框架, 完整的项目代码可以参考这里: https://github.com/kfggww/cutest . 近期对cutest的实现做了一些修改, 包括: Test Suite的声明宏, 修改为TEST_SUITE 增加Test Suite的声明宏TEST_SUITE_WITH. 可传递Suite的init和cleanup函数, 在Suite中每个Cas

    2024年02月12日
    浏览(35)
  • ​基于多种语言,使用Selenium实现自动化的常用单元测试框架

    Selenium是自动化网络应用程序的首选工具。Selenium支持基于Java、C#、PHP、Ruby、Perl、JavaScript和Python等多种编程语言的各种单元测试框架。这些框架用于在 Windows、MacOS 和 Linux 等不同平台的网络应用程序上执行测试脚本。任何成功的自动化流程都有赖于强大的测试框架,这些框架

    2024年01月21日
    浏览(55)
  • 【MFAC】基于紧格式动态线性化的无模型自适应控制

    来源:侯忠生教授的《无模型自适应控制:理论与应用》(2013年科学出版社)。 👉对应书本 3.2 单输入单输出系统(SISO)紧格式动态线性化(CFDL) 和 4.2 单输入单输出系统(SISO)紧格式动态线性化(CFDL)的无模型自适应控制(MFAC) 紧格式动态线性化 (compact form dynamic linearization) SISO离散

    2024年02月02日
    浏览(42)
  • 支持向量机SVM(包括线性核、多项式核、高斯核)python手写实现+代码框架说明

    理论参考《统计学习方法》Chapter.7 支持向量机(SVM) 完整代码见github仓库:https://github.com/wjtgoo/SVM-python 借鉴sklearn的代码构架,整体功能实现在SVM类中,包括各种类属性,以及常用的模型训练函数 SVM.fit(x,y,iterations) ,以及预测函数 SVM.predict(x) , 类输入参数 kernal: 默认:线性

    2023年04月17日
    浏览(89)
  • 【一起学Rust | 框架篇 | Frui框架】rust一个对开发者友好的GUI框架——Frui

    本次内容接上回《rust原生跨平台GUI框架——iced》,最近突然涌现出多个Rust的UI框架,真实令人兴奋,同时也突出了Rust的勃然生机,我将尽量为大家介绍Rust领域的UI框架,带大家体验Rust领域的好玩意儿。 Frui是一个对开发者相当友好的UI框架,它使得开发者构建用户界面变得

    2024年02月01日
    浏览(42)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包