Pytorch 的 LSTM 模型的简单示例

这篇具有很好参考价值的文章主要介绍了Pytorch 的 LSTM 模型的简单示例。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1. 代码

完整的源代码:

import torch
from torch import nn

# 定义一个LSTM模型
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # 初始化隐藏状态h0, c0为全0向量
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        # 将输入x和隐藏状态(h0, c0)传入LSTM网络
        out, _ = self.lstm(x, (h0, c0))
        # 取最后一个时间步的输出作为LSTM网络的输出
        out = self.fc(out[:, -1, :])
        return out

# 定义LSTM超参数
input_size = 10   # 输入特征维度
hidden_size = 32  # 隐藏单元数量
num_layers = 2    # LSTM层数
output_size = 2   # 输出类别数量

# 构建一个随机输入x和对应标签y
x = torch.randn(64, 5, 10)  # [batch_size, sequence_length, input_size]
y = torch.randint(0, 2, (64,))  # 二分类任务,标签为0或1

# 创建LSTM模型,并将输入x传入模型计算预测输出
lstm = LSTM(input_size, hidden_size, num_layers, output_size)
pred = lstm(x)  # [batch_size, output_size]

# 定义损失函数和优化器,并进行模型训练
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-3)
num_epochs = 100

for epoch in range(num_epochs):
    # 前向传播计算损失函数值
    pred = lstm(x)  # 在每个epoch中重新计算预测输出
    loss = criterion(pred.squeeze(), y)

    # 反向传播更新模型参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 输出每个epoch的训练损失
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

2. 模型结构分析

# 定义一个LSTM模型
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # 初始化隐藏状态h0, c0为全0向量
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        # 将输入x和隐藏状态(h0, c0)传入LSTM网络
        out, _ = self.lstm(x, (h0, c0))
        # 取最后一个时间步的输出作为LSTM网络的输出
        out = self.fc(out[:, -1, :])
        return out

上述代码定义了一个LSTM类,这个类可以用于完成一个基于LSTM的序列模型的搭建。

在初始化函数中,输入的参数分别是输入数据的特征维度(input_size),隐藏层的大小(hidden_size),LSTM层数(num_layers)以及输出数据的维度(output_size)。这里使用batch_first=True表示输入数据的第一个维度是batch size,第二个维度是时间步长和特征维度。

在forward函数中,首先初始化了LSTM网络的隐藏状态为全0向量,并且将其移动到与输入数据相同的设备上。然后调用了nn.LSTM函数进行前向传播操作,并且通过fc层将最后一个时间步的输出映射为输出的数据,最后进行了返回。

3. 代码详解

        # 将输入x和隐藏状态(h0, c0)传入LSTM网络
        out, _ = self.lstm(x, (h0, c0))

这行代码是利用 PyTorch 自带的 LSTM 模块处理输入张量 x(形状为 [batch_size, sequence_length, input_size])并得到 LSTM 层的输出 out 和最终状态。其中,h0 是 LSTM 层的初始隐藏状态,c0 是 LSTM 层的初始细胞状态。

在代码中,调用了 self.lstm(x, (h0, c0)) 函数,该函数的返回值有两个:第一个返回值是 LSTM 层的输出 out,其包含了所有时间步上的隐状态;第二个返回值是一个元组,包含了最后一个时间步的隐藏状态和细胞状态,但我们用“_”丢弃了它。

因为对于许多深度学习任务来说,只需要输出序列的最后一个时间步的隐藏状态,而不需要每个时间步上的隐藏状态。因此,这里我们只保留 LSTM 层的输出 out,而忽略了 LSTM 层最后时间步的状态。

最后,out 的形状为 [batch_size, sequence_length, hidden_size],其中 hidden_size 是 LSTM 层输出的隐藏状态的维度大小。

x = torch.randn(64, 5, 10)

这行代码创建了一个形状为 (64, 5, 10) 的张量 x,它包含 64 个样本,每个样本具有 5 个特征维度和 10 个时间步。该张量的值是由均值为 0,标准差为 1 的正态分布随机生成的。

torch.randn() 是 PyTorch 中生成服从标准正态分布的随机数的函数。它的输入是张量的形状,输出是符合正态分布的张量。在本例中,形状为 (64, 5, 10) 表示该张量包含 64 个样本,每个样本包含 5 个特征维度和 10 个时间步,每个元素都是服从标准正态分布的随机数。这种方式生成的随机数可以用于初始化模型参数、生成噪音数据等许多深度学习应用场景。

y = torch.randint(0, 2, (64,))  # 二分类任务,标签为0或1

y = torch.randint(0, 2, (64,)) 是使用 PyTorch 库中的 randint() 函数来生成一个64个元素的张量 y,张量的每个元素都是从区间 [0, 2) 中随机生成的整数。

具体而言,torch.randint() 函数包含三个参数,分别是 low、high 和 size。其中,low 和 high 分别表示随机生成整数的区间为 [low, high),而 size 参数指定了生成的张量的形状。

在上述代码中,size=(64,) 表示生成的张量 y 的形状为 64x1,即一个包含 64 个元素的一维张量,并且每个元素的值都在 [0, 2) 中随机生成。这种形式的张量通常用于分类问题中的标签向量。在该任务中,一个标签通常由一个整数表示,因此可以采用使用 randint() 函数生成一个长度为标签类别数的一维张量,其每个元素的取值为 0 或 1,表示对应类别是否被选中。

# 创建LSTM模型,并将输入x传入模型计算预测输出
lstm = LSTM(input_size, hidden_size, num_layers, output_size)
pred = lstm(x)  # [batch_size, output_size]

通过定义的LSTM类创建了一个LSTM模型,并将输入x传入模型进行前向计算,得到了一个预测输出pred,其形状为[64, output_size],其中output_size是在LSTM初始化函数中指定的输出数据的维度。

这段代码演示了如何使用已经构建好的代码搭建并训练一个基于LSTM的序列模型,并且展示了其中的一些关键步骤,包括数据输入、模型创建以及前向计算。文章来源地址https://www.toymoban.com/news/detail-462549.html

到了这里,关于Pytorch 的 LSTM 模型的简单示例的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【VAR | 时间序列】以美国 GDP 和通货膨胀数据为例的VAR模型简单实战(含Python源代码)

    以美国 GDP 和通货膨胀数据为例: 下载数据我们需要从 FRED 数据库下载美国 GDP 和通货膨胀数据,并将它们存储在 CSV 文件中。可以在 FRED 网站(https://fred.stlouisfed.org/)搜索并下载需要的数据。在这里,并且将它们命名为 ‘gdp.csv’ 和 ‘inflation.csv’。 网站为: 在搜索栏中输

    2024年02月02日
    浏览(101)
  • Tensorflow车牌识别完整项目(含完整源代码及训练集)

    基于TensorFlow的车牌识别系统设计与实现,运用tensorflow和OpenCV的相关技术,实现车牌的定位、车牌的二值化、车牌去噪增强、图片的分割,模型的训练和车牌的识别等 项目问题,毕设,大创可私聊博主 目录 环境准备 思路流程 功能描述 细节阐述 项目总体框架 过程展示 技术

    2024年02月02日
    浏览(47)
  • Python:实现图片叠加效果,附带完整源代码

    Python:实现图片叠加效果,附带完整源代码 在图像处理中,叠加图片是一种广泛应用且非常实用的技术。通过将两张或多张图片叠加在一起,可以达到更好的视觉效果。本文将介绍如何使用Python实现图片叠加功能,并提供完整的源代码。 首先需要安装所需的Python库——Pill

    2024年02月11日
    浏览(53)
  • 微信小程序 - 超详细小程序接入腾讯地图的完整流程,提供地图显示、IP 属地定位、地理位置名称、获取经纬度等超多功能示例(可一键复制并运行的功能源代码,详细的注释及常见问题汇总)小白直接上手!

    网上的教程代码太乱了,第一次接触的朋友极其难搞,更别说把功能改造移植到自己的项目中去。 本文站在小白的角度, 实现了微信小程序开发中,集成腾讯地图的详细流程及使用方法教程,提供了地图显示、IP 属地定位、当前定位的地理位置名称、当前定位的经纬度等等

    2024年02月16日
    浏览(60)
  • (纯) 基于JAVAWEB的网上购物平台(完整源代码)

    摘要        随着计算机网络技术的飞速发展和人们生活节奏的不断加快,电子商务技术已经逐渐融入了人们的日常生活当中,网上商城作为电子商务最普遍的一种形式,已被大众逐渐接受。因此开发一个网上商城系统,适合当今形势,更加方便人们在线购物。     本网上商

    2024年02月08日
    浏览(56)
  • 28个炫酷的CSS特效动画示例(含源代码)

    CSS是网页的三驾马车之一,是对页面布局的总管家,2024年了,这里列出28个超级炫酷的纯CSS动画示例,让您的网站更加炫目多彩。 效果图: 点击查看示例源代码 效果图: 点击查看示例源代码 效果图: 点击查看示例源代码 效果图: 点击查看示例源代码 效果图: 点击查看示

    2024年01月16日
    浏览(50)
  • PINN神经网络源代码解析(pyTorch)

    PINN(Physics-informed Neural Networks)的原理部分可参见https://maziarraissi.github.io/PINNs/ 考虑Burgers方程,如下图所示,初始时刻 u 符合 sin 分布,随着时间推移在 x=0 处发生间断. 这是一个经典问题,可使用 pytorch 通过PINN实现对Burgers方程的求解。 源代码共含有三个文件,来源于Github htt

    2024年02月12日
    浏览(99)
  • 28个炫酷的纯CSS特效动画示例(含源代码)

    CSS是网页的三驾马车之一,是对页面布局的总管家,2024年了,这里列出28个超级炫酷的纯CSS动画示例,让您的网站更加炫目多彩。 效果图: 点击查看示例源代码 效果图: 点击查看示例源代码 效果图: 点击查看示例源代码 效果图: 点击查看示例源代码 效果图: 点击查看示

    2024年01月20日
    浏览(54)
  • 基于python+mysql超市信息管理系统(附完整源代码)

    (参考的是这篇文章(5条消息) 数据库课程设计—超市零售信息管理系统(Python实现)_小桃在改bug的博客-CSDN博客_超市管理系统数据库设计但是这篇文章里没有完整的代码,所以我自己补全了ui界面和相关的代码,并进行了二创,框架也有改动,更主要的是写出来自己在编写过

    2024年02月03日
    浏览(49)
  • 服务端和客户端通信--UDP(含完整源代码)

    实验设备:     目标系统:Windows 软件工具:vs2022/vc6/dev   实验要求: 完成UDP服务端和客户端的程序编写; 分别实现UDP一对一通信和广播通信功能。 实验内容: -static-libgcc 一对一通信 : 1 、加载/释放Winsock库,创建套接字(WSAStartup()/socket())。 加载方法: WSADATA wsa; /*初始化

    2024年02月14日
    浏览(55)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包