PyTorch Lightning快速学习教程一:快速训练一个基础模型

这篇具有很好参考价值的文章主要介绍了PyTorch Lightning快速学习教程一:快速训练一个基础模型。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

粉丝量突破1200了!找到了喜欢的岗位,毕业上班刚好也有20天,为了督促自己终身学习的态度,继续开始坚持写写博客,沉淀并总结知识!
介绍:PyTorch Lightning是针对科研人员、机器学习开发者专门设计的,能够快速复用代码的一个工具,避免了因为每次都编写相似的代码而带来的时间成本。其可以理解为,lightning设计了一个,能够快速搭建训练验证测试模型的整套代码模板,我们只需要编写设计需要的模型、超参数、优化器等,直接套进去即可。lightning的优势在于:灵活性高、可读性强、支持多卡训练、内置测试、内置日志等。

前置掌握知识:Python和PyTorch的使用

链接:https://lightning.ai/

快速安装:pip install lightning

1.添加依赖包

需要添加相应的依赖,包括os,torch工具包,torch数据载入等依赖

import os		
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as pl
2.定义模型

PyTorch定义模型案例如下,定义好了方便后续的调用

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    def forward(self, x):
        return self.l1(x)	# 全连接 激活 全连接

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    def forward(self, x):
        return self.l1(x)	# 全连接 激活 全连接
3.定义网络架构

定义网络模型,自定义模型名字,并继承lightning.pytorch.LightningModule类,如下代码

  • training_step定义了与nn.Module之间交互

  • configure_optimizers为模型定义优化器

class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
4.定义训练集

定义DataLoader,这一点跟PyTorch调模型的流程一样,如下调用了MNIST公开数据集

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)
5.训练数据

使用Lightning来处理所有的训练,如下代码。

# model 模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# train model 训练
trainer = pl.Trainer()
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

一般的训练过程,需要设计如下代码,进行遍历和循环训练,Lightning会消除这些繁琐的过程,使用Lightning,可以将所有这些技术混合在一起,而无需每次都重写一个新的循环。文章来源地址https://www.toymoban.com/news/detail-605108.html

autoencoder = LitAutoEncoder(Encoder(), Decoder())
optimizer = autoencoder.configure_optimizers()

for batch_idx, batch in enumerate(train_loader):
    loss = autoencoder.training_step(batch, batch_idx)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
完整代码
# coding:utf-8
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L

# --------------------------------
# Step 1: 定义一个 LightningModule
# --------------------------------
# A LightningModule (nn.Module subclass) defines a full *system*
# (例如: an LLM, diffusion model, autoencoder, or simple image classifier).


class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # forward 定义了一次 预测/推理 行为
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step 定义了一次训练的迭代, 和forward相互独立
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# -------------------
# Step 2: 定义数据集
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

# -------------------
# Step 3: 开始训练
# -------------------
autoencoder = LitAutoEncoder()
trainer = L.Trainer(accelerator="gpu")	
trainer.fit(autoencoder, data.DataLoader(train,batch_size=128), data.DataLoader(val))

到了这里,关于PyTorch Lightning快速学习教程一:快速训练一个基础模型的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • dbGet 快速学习教程

    往期文章链接: innovus/ICC2: 命令对照表 常用dbGet命令 dbGet是innovus/encounter工具自带的\\\"database access command\\\"命令中的一部分,它几乎可以用来获取设计相关的一切信息。 输入dbGet 按 [Tab] 键,能看到三个选项,分别是head / top /selected。这三个选项所代表的意义如下: head --- 工艺信息

    2024年02月09日
    浏览(75)
  • (9)OpenCV深度学习系列教程——PyTorch入门

    作者:禅与计算机程序设计艺术 PyTorch是一个由Facebook开发的开源机器学习框架,它提供了一整套用于训练、评估和部署深度学习模型的工具和方法。随着深度学习在各个领域的应用越来越广泛,PyTorch作为一个成熟的框架已经成为机器学习研究人员的必备工具。本系列教程从

    2024年02月07日
    浏览(44)
  • 【深度学习】AIGC ,ControlNet 论文,原理,训练,部署,实战,教程(三)

    第一篇:https://qq742971636.blog.csdn.net/article/details/131531168 目前 ControlNet 1.1 还在建设,本文这里使用源码 https://github.com/lllyasviel/ControlNet/tree/main。 此外还需要下载模型文件:https://huggingface.co/lllyasviel/ControlNet 发布在huggingface了,如何下载huggingface的模型文件,使用指令: 详细lo

    2024年02月12日
    浏览(45)
  • 【深度学习】AIGC ,ControlNet 论文,原理,训练,部署,实战,教程(一)

    论文:https://arxiv.53yu.com/pdf/2302.05543 代码:https://github.com/lllyasviel/ControlNet 得分几个博客完成这个事情的记录了,此篇是第一篇,摘录了一些论文内容。ControlNet 的原理极为朴实无华(对每个block添加zero conv连接),但却非常有效地减少了训练资源和训练时间,针对不同领域任

    2024年02月15日
    浏览(39)
  • 完整教程:深度学习环境配置(GPU条件&pytorch)

    如果是python小白,强烈推荐B站小土堆的视频,讲得很清晰(但需要花些时间),地址如下: 最详细的 Windows 下 PyTorch 入门深度学习环境安装与配置 CPU GPU 版 如果有些基础,跟着往下看就行。 配置 作用 Anaconda 灵活切换python运行环境、高效使用python包 GPU 软硬件:硬件基础(

    2024年02月15日
    浏览(40)
  • 最简单Anaconda+PyTorch深度学习环境配置教程

    深度学习小白从零开始学习配置环境,记录一下踩过的雷坑,做个学习笔记。 配置了好几次之后总结出来的最简单,试错成本最小的方案,分享给大家~ 安装顺序:Anaconda+CUDA+ CuDnn+Pytorch  Anaconda ,中文 大蟒蛇 ,是一个开源的Python发行版本,其包含了conda、Python等180多个科学

    2024年02月02日
    浏览(64)
  • 深度学习框架教程:介绍一些流行的深度学习框架 (如TensorFlow、PyTorch等)

    目录 一、引言 二、TensorFlow 三、Keras 四、PyTorch 五、技巧与最佳实践

    2024年02月02日
    浏览(44)
  • 【深度学习】Pytorch 系列教程(十二):PyTorch数据结构:4、数据集(Dataset)

             目录 一、前言 二、实验环境 三、PyTorch数据结构 0、分类 1、张量(Tensor) 2、张量操作(Tensor Operations) 3、变量(Variable) 4、数据集(Dataset) 随机洗牌           ChatGPT:         PyTorch是一个开源的机器学习框架,广泛应用于深度学习领域。它提供了丰富

    2024年02月07日
    浏览(43)
  • 配置Pytorch(深度学习)环境极其详细教程,解释按钮和命令

     打开  依次点击下面这个  开始创建 下面几个选项分别是 已经安装的 没有安装的 可以更新的 已经删除的 所有的  然后去pycharm里选到把这些新创建的环境下的python.exe这个解释器添加进去,就成功让程序在这个环境里运行了  先点圆圈里的内容,然后那两个随便点一个 点

    2024年02月08日
    浏览(44)
  • Docker 快速上手学习入门教程

    目录 1、docker 的基础概念 2、怎样打包和运行一个应用程序? 3、如何对 docker 中的应用程序进行修改? 4、如何对创建的镜像进行共享? 5、如何使用 volumes 名称对容器中的数据进行存储?// 数据挂载 6、另一种挂载方式:目录挂载 7、实现容器之间的相互通信 8、使用 Docker

    2024年02月09日
    浏览(39)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包