pytorch lightning 入门

这篇具有很好参考价值的文章主要介绍了pytorch lightning 入门。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

15分钟了解Pytorch Lightning

翻译自官方文档

前置知识:推荐pytorch
目标:通过PL中7个关键步骤了解PL工作流程

PL是基于pytorch的高层API,自带丰富的工具为AI学者和工程师快速创建高性能模型,去除繁琐的重复流程同时保持灵活性。

使用组织好的pytorch代码,PL可以:

  1. 避免重复流程。比如gpu设置,device设置,backward()等。
  2. 高可读性并更容易复现。PL将代码按照运行周期设置了不同HOOK,你可以很容易地找到关键代码。然后将其它方法迁移复现。比如损失计算在training_step()方法中。
  3. 简单设置即可使用多GPU策略。只需对Trainer配置设置和策略,无需在模型部分操作。
  4. 快速的test流程。无需额外写测试逻辑,只需和train流程类似,实现test_step(),内部完成推理和保存。调用trainer.test()即可完成测试流程。

第一步 安装PL

pip 安装

pip install lightning

conda 安装

conda install lightning -c conda-forge

第二步 定义一个LightningModule子类

子类中可以使用原生pytorch nn.Module创建的模块搭建PL的模型,然后在training_step()方法中实现损失计算的过程。将LightningModule的子类传入Trainer后,即会自动调用计算损失并反向传播。

# 导入包
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl

# 定义一个线性层和激活函数组成的编码器-解码器,你可以使用任何nn.Module创建的模块
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# 定义一个继承LightningModule的子类
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        # 设置属性,模型包含两个子模块
        self.encoder = encoder
        self.decoder = decoder

    # 单步训练过程,里面包含forward流程,可以单独写出
    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        # 计算损失
        loss = nn.functional.mse_loss(x_hat, x)
        # 记录训练过程中的损失。如果要使用tensorboard,wandb等工具,需要在trainer创建时指定,并如下设置logger=Ture
        self.log("train_loss", loss, logger=Ture)
        return loss
    # 配置优化器
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# 创建一个模型
autoencoder = LitAutoEncoder(encoder, decoder)

第三步 定义一个数据集加载器

任何可迭代iterable 的对象(list, dataloader, dict, numpy等)均可作为加载器。

# 这里使用的torchvision中包含的MNIST数据集,可以自定义,方法和pytorch一样实现get_item()方法。
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

第四步 训练模型

创建一个训练管理器trainer,调用fit()方法时传入模型和数据集加载器,开始训练。

# 创建Trainer时有不同参数,详见API说明
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
# 传入模型和数据集加载器并开始训练
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

Trainer 包含train/validation/test loop,日志记录、gpu设置、模型保存策略
等逻辑,避免花费过多精力在重复的流程上。

第五步 使用模型

训练完成后可以将模型转换格式部署至生产环境中,也只可加载权重预测。

# 加载checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# 预测
fake_image_batch = Tensor(4, 28 * 28)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

第六步 训练监控

如果创建Trainer时指定了logger,如tensorboard, wandb并给定相关设置,即可打开日志文件查看训练状态。比如使用了tensorboard作为logger,执行以下命令后,浏览器打开 http://localhost:6006/

cd path/to/your/log/file
tensorboard --logdir .

第七步 训练的高级功能

在创建Trainer时你可以使用不同的设备(cpu、gpu、tpu等)、不同的策略、精度,以及callback。

# 4块GPU上训练
trainer = Trainer(
    devices=4,
    accelerator="gpu",
 )

# 通过Deepspeed/fsdp策略,在16精度下由4块GPU训练
trainer = Trainer(
    devices=4,
    accelerator="gpu",
    strategy="deepspeed_stage_2",
    precision=16
 )

# Trainer 中20多个有用的配置项,比如
trainer = Trainer(
    max_epochs=10,# 最大训练10个epochs
    min_epochs=5,# 最小训练5个epoch
    overfit_batches=0.01 # 使用0.01的数据训练,以快速测试代码,默认为1
 )

# 将实现了callback接口的模块整合进流程中,callback中不同流程的处理逻辑
trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])

参考文档

  1. Pytorch Lightning

个人记录
https://github.com/Githubwujinming/LightningInstruction文章来源地址https://www.toymoban.com/news/detail-403046.html

到了这里,关于pytorch lightning 入门的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • PyTorch Lightning教程七:可视化

    本节指导如何利用Lightning进行可视化和监控模型 为何需要跟踪参数 在模型开发中,我们跟踪感兴趣的值,例如validation_loss,以可视化模型的学习过程。模型开发就像驾驶一辆没有窗户的汽车,图表和日志提供了窗口,让我们知道该把车开到哪里。有了Lightning,几乎可以可视

    2024年02月14日
    浏览(39)
  • PyTorch Lightning教程八:用模型预测,部署

    关于Checkpoints的内容在教程2里已经有了详细的说明,在本节,需要用它来利用模型进行预测 加载checkpoint并预测 使用模型进行预测的最简单方法是使用LightningModule中的load_from_checkpoint加载权重。 predict_step方法 加载检查点并进行预测仍然会在预测阶段的epoch留下许多boilerplate,

    2024年02月12日
    浏览(39)
  • PyTorch Lightning教程四:超参数的使用

    如果需要和命令行接口进行交互,可以使用Python中的argparse包,快捷方便,对于Lightning而言,可以利用它,在命令行窗口中,直接配置超参数等操作,但也可以使用LightningCLI的方法,更加轻便简单。 ArgumentParser ArgumentParser是Python的内置特性,进而构建CLI程序,我们可以使用它

    2024年02月15日
    浏览(35)
  • PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

    介绍:上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型(仅使用训练集),为了保证模型可以泛化到未见过的数据上,数据集通常被分为训练和测试两个集合,测试集与训练集相互独立,用以测试模型的泛化能力。本期通过增加验证和测试集来达到该目的,同时,还

    2024年02月16日
    浏览(34)
  • 变分自编码器(VAE)PyTorch Lightning 实现

    ✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。 🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。 变分自编码器 (Variational Autoencoder,VAE)是一

    2024年02月21日
    浏览(51)
  • (5)深度学习学习笔记-多层感知机-pytorch lightning版

    pytorch lighting是导师推荐给我学习的一个轻量级的PyTorch库,代码干净简洁,使用pl更容易理解ML代码,对于初学者的我还是相对友好的。 pytorch lightning官网网址 https://lightning.ai/docs/pytorch/stable/levels/core_skills.html 代码如下: 代码如下:(可以直接把download改为true下载) 更多pl的方

    2024年02月12日
    浏览(44)
  • 版本匹配指南:PyTorch版本、Python版本和pytorch_lightning版本的对应关系

    版本匹配指南:PyTorch版本、Python版本和pytorch_lightning版本的对应关系 🌈 欢迎莅临 我的个人主页👈这里是我 静心耕耘 深度学习领域、 真诚分享 知识与智慧的小天地!🎇 🎓 博主简介: 我是 高斯小哥 ,一名来自985高校的普通本硕生,曾有幸在中科院顶刊发表过 一作论文

    2024年04月17日
    浏览(65)
  • PyTorch Lightning快速学习教程一:快速训练一个基础模型

    粉丝量突破1200了!找到了喜欢的岗位,毕业上班刚好也有20天,为了督促自己终身学习的态度,继续开始坚持写写博客,沉淀并总结知识! 介绍:PyTorch Lightning是针对科研人员、机器学习开发者专门设计的,能够快速复用代码的一个工具,避免了因为每次都编写相似的代码而

    2024年02月16日
    浏览(55)
  • 关于安装 PyTorch-Lightning 的一些问题(GPU版)

    官网地址: PyTorch PyTorch-Lightning 1、不能直接使用 pip install pytorch-lightning  ,否则如下图会直接卸载掉你的torch而安装cpu版本的torch。 2、在安装pytorch-lightning时一定注意自己的torch是pip安装还是conda安装,两者要保持一致,否则也会导致你的torch版本被替换。 正确安装方式: p

    2024年02月02日
    浏览(33)
  • PyTorch Lightning:通过分布式训练扩展深度学习工作流

              欢迎来到我们关于 PyTorch Lightning 系列的第二篇文章!在上一篇文章中,我们向您介绍了 PyTorch Lightning,并探讨了它在简化深度学习模型开发方面的主要功能和优势。我们了解了 PyTorch Lightning 如何为组织和构建 PyTorch 代码提供高级抽象,使研究人员和从业者能够

    2024年02月11日
    浏览(45)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包