LSTM从入门到精通(形象的图解,详细的代码和注释,完美的数学推导过程)

这篇具有很好参考价值的文章主要介绍了LSTM从入门到精通(形象的图解,详细的代码和注释,完美的数学推导过程)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

先附上这篇文章的一个思维导图

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

  1. 什么是RNN

按照八股文来说:RNN实际上就是一个带有 记忆的时间序列的预测模型

RNN的细胞结构图如下:

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档
softmax激活函数只是我举的一个例子,实际上得到y<t>也可以通过其他的激活函数得到

其中a<t-1>代表t-1时刻隐藏状态,a<t>代表经过X<t>这一t时刻的输入之后,得到的新的隐藏状态。公式主要是a<t> = tanh(Waa * a<t-1> + Wax * X<t> + b1) ;大白话解释一下就是,X<t>是今天的吊针,a<t-1>是昨天的发烧度数39,经过今天这一针之后,a<t>变成38度。这里的记忆体现在今天的38度是在前一天的基础上,通过打吊针来达到第二天的降温状态。


1.1 RNN的应用

由于RNN的记忆性,我们最容易想到的就是RNN在自然语言处理方面的应用,譬如下面这张图,提前预测出下一个字。

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

除此之外,RNN的应用还包括下面的方向:

  1. 语言模型:RNN被广泛应用于语言模型的建模中,例如自然语言处理、机器翻译、语音识别等领域。

  1. 时间序列预测:RNN可以用于时间序列预测,例如股票价格预测、气象预测、心电图信号预测等。

  1. 生成模型:RNN可以用于生成模型,例如文本生成、音乐生成、艺术创作等。

  1. 强化学习:RNN可以用于强化学习中,例如在游戏、机器人控制和决策制定等领域。


1.2 RNN的缺陷

想必大家一定听说过LSTM,没错,就是由于RNN的尿性,所以才出现LSTM这一更精妙的时间序列预测模型的设计。但是我们知己知彼才能百战百胜,因此我还是决定详细分析一下RNN的缺陷,看过一些资料,但是只是肤浅的提到了梯度消失和梯度爆炸,没有实际的数学推导,这可不是一个求学之人应该有的态度!

主要的缺陷是两个:

  1. 长期依赖问题导致的梯度消失:众所周知RNN模型是一个具有记忆的模型,每一次的预测都和当前输入以及之前的状态有关,但是我们试想,如果我们的句子很长,他在第1000个记忆细胞还能记住并很好的利用第1个细胞的记忆状态吗?答案显然是否定的

  1. 梯度爆炸

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

1.2.1 梯度消失和梯度爆炸的详细公式推导

敲黑板(手写公式推导,大家最迷糊的地方):

根据下面图示的例子,我手写并反复检查了自己的过程(下图),请各位看官务必认真看看,理解起来并不难,对于别的文章随口一提的梯度消失和梯度爆炸实在是透彻太多啦!!!

我们假设损失函数 lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档 ,Y是实际值,O是预测值;首先,我们假设只有 三层,然后通过三层我们就能以此类推找出规律。反向传播我们需要对Wo,Wx,Ws,b四个变量都求偏导,在这里我们主要对Wx求偏导,其他三个以此类推,就很简单了。为了表示更清晰,笔者使用紫色的x表示乘法。
lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档
lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档
lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

根据推导的公式我们得到一个指数函数lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档,我们在高中时候就学到过指数函数的变化系数是极大的,因此在t趋于比较大的时候(也就是我们的句子比较长的时候),如果lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档比1小不少,那么模型的一部分梯度会趋于0,因此优化会几乎停止;同理,如果lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档比1大一些,那么模型的部分梯度会极大,导致模型和的变化无法控制,优化毫无意义。


  1. 什么是LSTM

八股文解释:LSTM(长短时记忆网络)是一种常用于处理序列数据的深度学习模型,与传统的 RNN(循环神经网络)相比,LSTM引入了三个门( 输入门、遗忘门、输出门,如下图所示)和一个 细胞状态(cell state),这些机制使得LSTM能够更好地处理序列中的长期依赖关系。注意:小蝌蚪形状表示的是sigmoid激活函数
lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档
Ct是细胞状态(记忆状态), lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档是输入的信息, lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档是隐藏状态(基于 lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档得到的)

用最朴素的语言解释一下三个门,并且用两门考试来形象的解释一下LSTM:

  1. 遗忘门:通过x和ht的操作,并经过sigmoid函数,得到0,1的向量,0对应的就代表之前的记忆某一部分要忘记,1对应的就代表之前的记忆需要留下的部分 ===>lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档代表复习上一门线性代数所包含的记忆,通过遗忘门,忘记掉和下一门高等数学无关的内容(比如矩阵的秩)

  1. 输入门:通过将之前的需要留下的信息和现在需要记住的信息相加,也就是得到了新的记忆状态。===>lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档代表复习下一门科目高等数学的时候输入的一些记忆(比如洛必达法则等等),那么已经线性代数残余且和高数相关的部分(比如数学运算)+高数的知识=新的记忆状态

  1. 输出门:整合lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档,得到一个输出===>lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档代表高数所需要的记忆,但是在实际的考试不一定全都发挥出来考到100分。因此,lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档则代表实际的考试分数

为了便于大家理解,附上几张非常好的图供大家理解完整的数据处理的流程:

遗忘门:

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

输入门:

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

细胞状态:

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

输出门:

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

2.1 LSTM的模型结构

这里有两张别的博主的很好的图,我在初学的时候也是恍然大悟:

图的出处

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档
lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

解释一下pytorch训练lstm所使用的参数:

  1. 这是利用pytorch调用LSTM所使用的参数

output,(h_n,c_n) = lstm (x, [ht_1, ct_1]),一般直接放入x就好,后面中括号的不用管
  1. 这是作为x(输入)喂给LSTM的参数

x:[seq_length, batch_size, input_size],这里有点反人类,batch_size一般都是放在开始的位置
  1. 这是pytorch简历LSTM是所需参数

lstm = LSTM(input_size,hidden_size,num_layers)

2.2 LSTM相比RNN的优势

LSTM的反向传播的数学推导很繁琐,因为涉及到的变量很多,但是LSTM确实是可以在一定程度上解决梯度消失和梯度爆炸的问题。我简单说一下,RNN的连乘主要是W的连乘,而W是一样的,因此就是一个指数函数(在梯度中出现指数函数并不是一件友好的事情);相反,LSTM的连乘是lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档的偏导的不断累乘,如果前后的记忆差别不大,那偏导的值就是1,那就是多个1相乘。当然,也可能出现某一一些偏导的值很大,但是一定不会很多(换句话说,一句话的前后没有逻辑,那完全没有训练的必要)。


2.3 pytorch实现LSTM对股票的预测(实战)

需要安装一下tushare的金融方面的数据集,代码的注解我已经写的很清楚了

#!/usr/bin/python3
# -*- encoding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import tushare as ts
import pandas as pd
import torch
from torch import nn
import datetime
import time

DAYS_FOR_TRAIN = 10


class LSTM_Regression(nn.Module):
    """
        使用LSTM进行回归

        参数:
        - input_size: feature size
        - hidden_size: number of hidden units
        - output_size: number of output
        - num_layers: layers of LSTM to stack
    """

    def __init__(self, input_size, hidden_size, output_size=1, num_layers=2):
        super().__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, _x):
        x, _ = self.lstm(_x)  # _x is input, size (seq_len, batch, input_size)
        s, b, h = x.shape
        x = x.view(s * b, h)
        x = self.fc(x)
        x = x.view(s, b, -1)  # 把形状改回来
        return x


def create_dataset(data, days_for_train=5) -> (np.array, np.array):
    """
        根据给定的序列data,生成数据集

        数据集分为输入和输出,每一个输入的长度为days_for_train,每一个输出的长度为1。
        也就是说用days_for_train天的数据,对应下一天的数据。

        若给定序列的长度为d,将输出长度为(d-days_for_train+1)个输入/输出对
    """
    dataset_x, dataset_y = [], []
    for i in range(len(data) - days_for_train):
        _x = data[i:(i + days_for_train)]
        dataset_x.append(_x)
        dataset_y.append(data[i + days_for_train])
    return (np.array(dataset_x), np.array(dataset_y))


if __name__ == '__main__':
    t0 = time.time()
    data_close = ts.get_k_data('000001', start='2019-01-01', index=True)['close']  # 取上证指数的收盘价
    data_close.to_csv('000001.csv', index=False) #将下载的数据转存为.csv格式保存
    data_close = pd.read_csv('000001.csv')  # 读取文件

    df_sh = ts.get_k_data('sh', start='2019-01-01', end=datetime.datetime.now().strftime('%Y-%m-%d'))
    print(df_sh.shape)

    data_close = data_close.astype('float32').values  # 转换数据类型
    plt.plot(data_close)
    plt.savefig('data.png', format='png', dpi=200)
    plt.close()

    # 将价格标准化到0~1
    max_value = np.max(data_close)
    min_value = np.min(data_close)
    data_close = (data_close - min_value) / (max_value - min_value)

    # dataset_x
    # 是形状为(样本数, 时间窗口大小)
    # 的二维数组,用于训练模型的输入
    # dataset_y
    # 是形状为(样本数, )
    # 的一维数组,用于训练模型的输出。
    dataset_x, dataset_y = create_dataset(data_close, DAYS_FOR_TRAIN)  # 分别是(1007,10,1)(1007,1)

    # 划分训练集和测试集,70%作为训练集
    train_size = int(len(dataset_x) * 0.7)

    train_x = dataset_x[:train_size]
    train_y = dataset_y[:train_size]

    # 将数据改变形状,RNN 读入的数据维度是 (seq_size, batch_size, feature_size)
    train_x = train_x.reshape(-1, 1, DAYS_FOR_TRAIN)
    train_y = train_y.reshape(-1, 1, 1)

    # 转为pytorch的tensor对象
    train_x = torch.from_numpy(train_x)
    train_y = torch.from_numpy(train_y)

    model = LSTM_Regression(DAYS_FOR_TRAIN, 8, output_size=1, num_layers=2)  # 导入模型并设置模型的参数输入输出层、隐藏层等

    model_total = sum([param.nelement() for param in model.parameters()])  # 计算模型参数
    print("Number of model_total parameter: %.8fM" % (model_total / 1e6))

    train_loss = []
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    for i in range(200):
        out = model(train_x)
        loss = loss_function(out, train_y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(loss.item())

        # 将训练过程的损失值写入文档保存,并在终端打印出来
        with open('log.txt', 'a+') as f:
            f.write('{} - {}\n'.format(i + 1, loss.item()))
        if (i + 1) % 1 == 0:
            print('Epoch: {}, Loss:{:.5f}'.format(i + 1, loss.item()))

    # 画loss曲线
    plt.figure()
    plt.plot(train_loss, 'b', label='loss')
    plt.title("Train_Loss_Curve")
    plt.ylabel('train_loss')
    plt.xlabel('epoch_num')
    plt.savefig('loss.png', format='png', dpi=200)
    plt.close()

    # torch.save(model.state_dict(), 'model_params.pkl')  # 可以保存模型的参数供未来使用
    t1 = time.time()
    T = t1 - t0
    print('The training time took %.2f' % (T / 60) + ' mins.')

    tt0 = time.asctime(time.localtime(t0))
    tt1 = time.asctime(time.localtime(t1))
    print('The starting time was ', tt0)
    print('The finishing time was ', tt1)




    # for test
    model = model.eval()  # 转换成评估模式
    # model.load_state_dict(torch.load('model_params.pkl'))  # 读取参数

    # 注意这里用的是全集 模型的输出长度会比原数据少DAYS_FOR_TRAIN 填充使长度相等再作图
    dataset_x = dataset_x.reshape(-1, 1, DAYS_FOR_TRAIN)  # (seq_size, batch_size, feature_size)
    dataset_x = torch.from_numpy(dataset_x)

    pred_test = model(dataset_x)  # 全量训练集
    # 的模型输出 (seq_size, batch_size, output_size)
    pred_test = pred_test.view(-1).data.numpy()
    pred_test = np.concatenate((np.zeros(DAYS_FOR_TRAIN), pred_test))  # 填充0 使长度相同
    assert len(pred_test) == len(data_close)

    plt.plot(pred_test, 'r', label='prediction')
    plt.plot(data_close, 'b', label='real')
    plt.plot((train_size, train_size), (0, 1), 'g--')  # 分割线 左边是训练数据 右边是测试数据的输出
    plt.legend(loc='best')
    plt.savefig('result.png', format='png', dpi=200)
    plt.close()
lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档

2.4 小问题:为什么采用tanh函数,不能都用sigmoid函数吗

先放上两个函数的图形:

lstm使用教程,lstm,深度学习,机器学习,人工智能,Powered by 金山文档
  1. Sigmoid函数比Tanh函数收敛饱和速度慢

  1. Sigmoid函数比Tanh函数值域范围更窄

  1. tanh的均值是0,Sigmoid均值在0.5左右,均值在0的数据显然更便于数据处理

  1. tanh的函数变化敏感区间更大

  1. 对两者求导,发现tanh对计算的压力更小,直接是1-原函数的平方,不需要指数操作文章来源地址https://www.toymoban.com/news/detail-687736.html

使用该问的图请标明出处,创作不易,希望收获你的赞赞

到了这里,关于LSTM从入门到精通(形象的图解,详细的代码和注释,完美的数学推导过程)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 点云从入门到精通技术详解100篇-基于 LSTM 的点云降采样(下)

    目录 3.3 实验结果与分析 3.3.1 点云分类实验 3.3.2 点云检索实验 3.4 消融实验

    2024年02月07日
    浏览(43)
  • Spring Boot进阶(94):从入门到精通:Spring Boot和Prometheus监控系统的完美结合

      随着云原生技术的发展,监控和度量也成为了不可或缺的一部分。Prometheus 是一款最近比较流行的开源时间序列数据库,同时也是一种监控方案。它具有极其灵活的查询语言、自身的数据采集和存储机制以及易于集成的特点。而 Spring Boot 是一款快速构建应用的框架,其提

    2024年02月08日
    浏览(43)
  • MVSNet代码超详细注释(PyTorch)

    网上解读MVSNet的博客已经很多了,大家可以自选学习,但更重要的是阅读理解原文,以及自己动手跑跑代码! MVSNet服务器环境配置及测试 https://blog.csdn.net/qq_43307074/article/details/128011842 【论文简述及翻译】MVSNet:Depth Inference for Unstructured Multi-view Stereo(ECCV 2018) https://blog.csd

    2024年02月02日
    浏览(49)
  • DenseNet代码复现+超详细注释(PyTorch)

    关于DenseNet的原理和具体细节,可参见上篇解读:经典神经网络论文超详细解读(六)——DenseNet学习笔记(翻译+精读+代码复现) 接下来我们就来复现一下代码。 整个DenseNet模型主要包含三个核心细节结构,分别是 DenseLayer (整个模型最基础的原子单元,完成一次最基础的

    2023年04月23日
    浏览(52)
  • ResNet代码复现+超详细注释(PyTorch)

    关于ResNet的原理和具体细节,可参见上篇解读:经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现) 接下来我们就来复现一下代码。 源代码比较复杂,感兴趣的同学可以上官网学习:  https://github.com/pytorch/vision/tree/master/torchvision 本

    2024年02月11日
    浏览(42)
  • ResNeXt代码复现+超详细注释(PyTorch)

    ResNeXt就是一种典型的混合模型,由基础的Inception+ResNet组合而成,本质在gruops分组卷积,核心创新点就是用一种平行堆叠相同拓扑结构的blocks代替原来 ResNet 的三层卷积的block,在不明显增加参数量级的情况下提升了模型的准确率,同时由于拓扑结构相同,超参数也减少了,便

    2024年02月15日
    浏览(50)
  • 【算法】顺时针打印矩阵(图文详解,代码详细注释

    目录 题目 代码如下: 输入一个矩阵,按照从外向里以顺时针的顺序依次打印出每一个数字。例如:如果输入如下矩阵: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 则打印出数字:1 2 3 4 8 12 16 15 14 13 9 5 6 7 11 10 这一道题乍一看,没有包含任何复杂的数据结构和高级算法,似乎蛮简单的。但

    2024年04月26日
    浏览(37)
  • 1146 Topological Order(31行代码+详细注释)

    分数 25 全屏浏览题目 作者 CHEN, Yue 单位 浙江大学 This is a problem given in the Graduate Entrance Exam in 2018: Which of the following is NOT a topological order obtained from the given directed graph? Now you are supposed to write a program to test each of the options. Input Specification: Each input file contains one test case. For each

    2024年02月06日
    浏览(77)
  • 非极大值抑制详细原理(NMS含代码及详细注释)

    作者主页: 爱笑的男孩。的博客_CSDN博客-深度学习,YOLO,活动领域博主 爱笑的男孩。擅长深度学习,YOLO,活动,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域. https://blog.csdn.net/Code_and516?type=collect 个人介绍:打工人。 分享内容

    2023年04月21日
    浏览(55)
  • 五子棋小游戏 java版(代码+详细注释)

    游戏展示         这周闲来无事,再来写个五子棋小游戏。基本功能都实现了,包括人人对战、人机对战。界面布局和功能都写的还行,没做到很优秀,但也不算差。如有需要,做个java初学者的课程设计或者自己写着玩玩也都是不错的(非常简单,小白照着就能写出来)。

    2024年02月07日
    浏览(47)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包