[PyTorch][chapter 46][LSTM -1]

这篇具有很好参考价值的文章主要介绍了[PyTorch][chapter 46][LSTM -1]。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

前言:

           长短期记忆网络(LSTM,Long Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的。

目录:

  1.      背景简介
  2.      LSTM Cell
  3.      LSTM 反向传播算法
  4.      为什么能解决梯度消失
  5.       LSTM 模型的搭建

一  背景简介:

       1.1  RNN

         RNN 忽略 模型可以简化成如下

      [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

       

          图中Rnn Cell 可以很清晰看出在隐藏状态。

            得到 后:

              一方面用于当前层的模型损失计算,另一方面用于计算下一层的[PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

    由于RNN梯度消失的问题,后来通过LSTM 解决 

       1.2 LSTM 结构

        [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn


二  LSTM  Cell

   LSTMCell(RNNCell) 结构

          [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

          前向传播算法 Forward

         2.1   更新: forget gate 忘记门

             [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

             将值朝0 减少, 激活函数一般用sigmoid

             输出值[0,1]

         2.2 更新: Input gate 输入门

                [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

                决定是不是忽略输入值

    

           2.3 更新: 候选记忆单元

                    [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

           2.4 更新: 记忆单元

               [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

             2.5  更新: 输出门

                决定是否使用隐藏值

                 [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn  

           2.6. 隐藏状态

                

           2.7  模型输出

                  [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

LSTM 门设计的解释一:

 输入门 ,遗忘门,输出门 不同取值组合的时候,记忆单元的输出情况

[PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn


三  LSTM 反向传播推导

      3.1 定义两个

             

            

    3.2  定义损失函数

            损失函数分为两部分: 

             时刻t的损失函数 

             时刻t后的损失函数[PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

              [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

      3.3 最后一个时刻的

              [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

 这里面要注意这里的[PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

    证明一下第二项,主要应用到微分的两个性质,以及微分和迹的关系:

   [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

     ... 公式1: 微分和迹的关系

       

     因为

    

   

           

     带入上面公式1:

      

           

    所以

[PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

3.4   链式求导过程

       求导结果:

 [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

  这里详解一下推导过程:

  这是一个符合函数求导:先把h 写成向量形成

 ------------------------------------------------------------   

 第一项: 

             

         [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

         [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

        设 [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

           则    [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

 

            其中:(利用矩阵求导的定义法 分子布局原理)

                    [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn 是一个对角矩阵

                  

                 [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

                 [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

                 几个连乘起来就是第一项

               [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

第二项

    [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

   [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

   [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

  [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

参考:

   

其中:

[PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

[PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

 [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

 [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

其它也是相似,就有了上面的求导结果

[PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn


四  为什么能解决梯度消失

    [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

     4.1 RNN 梯度消失的原理

                ,复旦大学邱锡鹏书里面 有更加详细的解释,通过极大假设:

在梯度计算中存在梯度的k 次方连乘 ,导致 梯度消失原理。

    4.2  LSTM 解决梯度消失 解释1:

            通过上面公式发现梯度计算中是加法运算,不存在连乘计算,

            极大概率降低了梯度消失的现象。

    4.3  LSTM 解决梯度 消失解释2:

              记忆单元c  作用相当于ResNet的残差部分.  

   比如 时候,,不会存在梯度消失。

       


五 模型的搭建

[PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

   [PyTorch][chapter 46][LSTM -1],lstm,人工智能,rnn

    我们最后发现:

     的维度必须一致,都是hidden_size

    通过,则  最后一个维度也必须是hidden_size

    

# -*- coding: utf-8 -*-
"""
Created on Thu Aug  3 15:11:19 2023

@author: chengxf2
"""

# -*- coding: utf-8 -*-
"""
Created on Wed Aug  2 15:34:25 2023

@author: chengxf2
"""

import torch
from torch import nn
from d21 import torch as d21


def normal(shape,devices):
    
    data = torch.randn(size= shape, device=devices)*0.01
    
    return data


def get_lstm_params(input_size, hidden_size,categorize_size,devices):
    


    
    #隐藏门参数
    W_xf= normal((input_size, hidden_size), devices)
    W_hf = normal((hidden_size, hidden_size),devices)
    b_f = torch.zeros(hidden_size,devices)
    
    #输入门参数
    W_xi= normal((input_size, hidden_size), devices)
    W_hi = normal((hidden_size, hidden_size),devices)
    b_i = torch.zeros(hidden_size,devices)
    

    
    #输出门参数
    W_xo= normal((input_size, hidden_size), devices)
    W_ho = normal((hidden_size, hidden_size),devices)
    b_o = torch.zeros(hidden_size,devices)
    
    #临时记忆单元
    W_xc= normal((input_size, hidden_size), devices)
    W_hc = normal((hidden_size, hidden_size),devices)
    b_c = torch.zeros(hidden_size,devices)
    
    #最终分类结果参数
    W_hq = normal((hidden_size, categorize_size), devices)
    b_q = torch.zeros(categorize_size,devices)
    
    
    params =[
        W_xf,W_hf,b_f,
        W_xi,W_hi,b_i,
        W_xo,W_ho,b_o,
        W_xc,W_hc,b_c,
        W_hq,b_q]
    
    for param in params:
        
        param.requires_grad_(True)
        
    return params

def init_lstm_state(batch_size, hidden_size, devices):
    
    cell_init = torch.zeros((batch_size, hidden_size),device=devices)
    hidden_init = torch.zeros((batch_size, hidden_size),device=devices)
    
    return (cell_init, hidden_init)


def lstm(inputs, state, params):
    [
        W_xf,W_hf,b_f,
        W_xi,W_hi,b_i,
        W_xo,W_ho,b_o,
        W_xc,W_hc,b_c,
        W_hq,b_q] = params    
    
    (H,C) = state
    outputs= []
    
    for x in inputs:
        
        #input gate
        I = torch.sigmoid((x@W_xi)+(H@W_hi)+b_i)
        F = torch.sigmoid((x@W_xf)+(H@W_hf)+b_f)
        O = torch.sigmoid((x@W_xo)+(H@W_ho)+b_o)
        
        C_tmp = torch.tanh((x@W_xc)+(H@W_hc)+b_c)
        C = F*C+I*C_tmp
        
        H = O*torch.tanh(C)
        Y = (H@W_hq)+b_q
        
        outputs.append(Y)
        
    return torch.cat(outputs, dim=0),(H,C)
        
    

def main():
    batch_size,num_steps =32, 35
    train_iter, cocab= d21.load_data_time_machine(batch_size, num_steps)

    

if __name__ == "__main__":
    
     main()

 参考

 

CSDN

https://www.cnblogs.com/pinard/p/6519110.html

57 长短期记忆网络(LSTM)【动手学深度学习v2】_哔哩哔哩_bilibili文章来源地址https://www.toymoban.com/news/detail-636475.html

到了这里,关于[PyTorch][chapter 46][LSTM -1]的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Pytorch的CNN,RNN&LSTM

    拿二维卷积举例,我们先来看参数 卷积的基本原理,默认你已经知道了,然后我们来解释pytorch的各个参数,以及其背后的计算过程。 首先我们先来看卷积过后图片的形状的计算: 参数: kernel_size :卷积核的大小,可以是一个元组,也就是(行大小,列大小) stride : 移动步长

    2024年02月04日
    浏览(50)
  • PyTorch训练RNN, GRU, LSTM:手写数字识别

    数据集:MNIST 该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片 图片来源:https://tensornews.cn/mnist_intro/ 神经网络:RNN, GRU, LSTM 【1】https://www.youtube.com/watch?v=Gl2WXLIMvKAlist=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vzindex=5

    2024年02月15日
    浏览(40)
  • 【PyTorch API】 nn.RNN 和 nn.LSTM 介绍和代码详解

    torch.nn.RNN 的 PyTorch 链接:torch.nn.RNN(*args, **kwargs) nn.RNN 的用法和输入输出参数的介绍直接看代码: 需要特别注意的是 nn.RNN 的第二个输出 hn 表示所有掩藏层的在最后一个 time step 隐状态,听起来很难理解,看下面的红色方框内的数据就懂了。即 output[:, -1, :] = hn[-1, : , :] 这里

    2024年02月12日
    浏览(39)
  • Pytorch 对比TensorFlow 学习:Day 17-18: 循环神经网络(RNN)和LSTM

    Day 17-18: 循环神经网络(RNN)和LSTM 在这两天的学习中,我专注于理解循环神经网络(RNN)和长短期记忆网络(LSTM)的基本概念,并学习了它们在处理序列数据时的应用。 1.RNN和LSTM基础: RNN:了解了RNN是如何处理序列数据的,特别是它的循环结构可以用于处理时间序列或连续

    2024年01月20日
    浏览(60)
  • Python深度学习026:基于Pytorch的典型循环神经网络模型RNN、LSTM、GRU的公式及简洁案例实现(官方)

    循环神经网络(也有翻译为递归神经网络)最典型的三种网络结构是: RNN(Recurrent Neural Network,循环神经网络) LSTM(Long Short-Term Memory,长短期记忆网络) GRU(Gate Recurrent Unit,门控循环单元) 理解参数的含义非常重要,否则,你不知道准备什么维度的输入数据送入模型 先

    2023年04月22日
    浏览(37)
  • LSTM已死,Transformer永生(面试问答RNN/LSTM/Transformer)

    计算机视觉面试题-Transformer相关问题总结 :https://zhuanlan.zhihu.com/p/554814230 计算机视觉面试31题 CV面试考点,精准详尽解析 :https://zhuanlan.zhihu.com/p/257883797 RNN的出现是为了解决输入输出没有严格对应关系的问题,该模型的输入和输出长度不需固定。 RNN的框架结构如下, X0, X1

    2024年02月12日
    浏览(39)
  • RNN&LSTM

    LSTM——起源、思想、结构 与“门” 完全图解RNN、RNN变体、Seq2Seq、Attention机制 完全解析RNN, Seq2Seq, Attention注意力机制 Sequence to sequence入门详解:从RNN, LSTM到Encoder-Decoder, Attention, transformer 从RNN到Attention到Transformer系列-Attention介绍及代码实现 提示:这里可以添加本文要记录的大概

    2024年02月16日
    浏览(38)
  • RNN & LSTM

    参考资料: 《机器学习2022》李宏毅 史上最详细循环神经网络讲解(RNN/LSTM/GRU) - 知乎 (zhihu.com) LSTM如何来避免梯度弥散和梯度爆炸? - 知乎 (zhihu.com) 首先考虑这样一个 slot filling 问题: 注意到,上图中 Taipei 的输出为 destination。如果我们只是单纯地将每个词向量输入到一个

    2024年02月16日
    浏览(36)
  • RNN和LSTM的区别是什么?

    RNN(循环神经网络)和LSTM(长短时记忆网络)都是处理序列数据(如时间序列或文本)的神经网络类型,但它们在结构和功能上有一些关键区别: 1. 基本结构: RNN: RNN的核心是一个循环单元,它在序列的每个时间步上执行相同的任务,同时保留一些关于之前步骤的信息。RNN的

    2024年01月21日
    浏览(37)
  • 相比于rnn, lstm有什么优势

    相对于常规的循环神经网络(RNN),长短期记忆网络(LSTM)具有以下优势: 处理长期依赖性:LSTM通过引入记忆单元和门控机制来解决传统RNN中的梯度消失和梯度爆炸问题。LSTM能够更好地捕捉时间序列数据中的长期依赖关系,使其在处理长序列和长期依赖性任务时表现更出色

    2024年01月20日
    浏览(37)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包