PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

这篇具有很好参考价值的文章主要介绍了PyTorch Lightning教程二:验证、测试、checkpoint、早停策略。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

介绍:上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型(仅使用训练集),为了保证模型可以泛化到未见过的数据上,数据集通常被分为训练和测试两个集合,测试集与训练集相互独立,用以测试模型的泛化能力。本期通过增加验证和测试集来达到该目的,同时,还引入checkpoint和早停策略,以得到模型最佳权重。

相关链接:https://lightning.ai/docs/pytorch/stable/levels/basic_level_2.html

训练集、验证集、测试集的使用

1.添加依赖,获取训练集和测试集

添加相应的依赖,同时使用MNIST数据集,获取训练和测试集

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 加载数据(测试集,train=False)
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
2.实现并调用test_step

在定义LightningModule中,实现test_step方法;在外部,调用test方法

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx): # 测试,该方法与training_step相似
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

# 初始化Trainer
trainer = Trainer()

# 执行test方法
trainer.test(model, dataloaders=DataLoader(test_set))
3.实现并调用验证集

通常使用torch.utils.data中的方法,将训练集中的一部分数据化为验证集

# 训练集中的20%数据划为验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# 拆分,使用data.random_split方法
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

与测试集一样,需要在定义LightningModule中,实现validation_step方法;在外部,调用fit方法

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)
    
    def test_step(self, batch, batch_idx):
        ...
# 调用torch.utils.data中的DataLoader对训练和测试集进行封装
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

# 在fit方法中,引入valid_loader,即验证集
trainer = Trainer()
trainer.fit(model, train_loader, valid_loader)

checkpoint

checkpoint有两个作用,一是能得到每一次epoch后的模型权重,能得到最佳表现的权重;二是能够在中断或停止后,继续在当前checkpoint处,继续训练。在Lightning中的checkpoint,包含模型的整个内部状态,这与普通的PyTorch不同,即使在最复杂的分布式训练环境中,Lightning也可以保存恢复模型所需的一切。包含以下状态:

  • 16-bit scaling factor (若使用16精度训练)
  • Current epoch
  • Global step
  • LightningModule’s state_dict
  • State of all optimizers
  • State of all learning rate schedulers
  • State of all callbacks (for stateful callbacks)
  • State of datamodule (for stateful datamodules)
  • The hyperparameters (init arguments) with which the model was created
  • The hyperparameters (init arguments) with which the datamodule was created
  • State of Loops
保存与调用方法
# 保存方法,可自定义default_root_dir路径,若不设置路径,将会自动保存
trainer = Trainer(default_root_dir="some/path/")

# 调用方法
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
model.eval()	# disable randomness, dropout, etc...
y_hat = model(x)

调用,还可以使用torch的方法

checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}

也可以实现重现,例如模型LitModel(in_dim=32, out_dim=10)

# 使用 in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# 使用 in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

Lightning和PyTorch完全兼容

checkpoint = torch.load(CKPT_PATH)
encoder_weights = checkpoint["encoder"]
decoder_weights = checkpoint["decoder"]

设置checkpoint不可见

trainer = Trainer(enable_checkpointing=False)

如果想全部重新恢复

model = LitModel()
trainer = Trainer()

自动恢复所有相关参数 model, epoch, step, LR schedulers, etc…

trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

早停策略

EarlyStopping Callback

在Lightning中,早停回调步骤如下:文章来源地址https://www.toymoban.com/news/detail-603699.html

  • Import EarlyStopping callback. 载入EarlyStopping回调方法
  • Log the metric you want to monitor using log() method. 加载日志方法
  • Init the callback, and set monitor to the logged metric of your choice. 设置monitor
  • Set the mode based on the metric needs to be monitored. 设置mode
  • Pass the EarlyStopping callback to the Trainer callbacks flag. 调入EarlyStropping
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = ...
        self.log("val_loss", loss)

model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)

# 也可以使用自定义的EarlyStopping策略
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
# EarlyStopping的文档链接https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
注意
  • EarlyStopping默认在一次Validation后调用,但是Validation可以自定义多少次epoch后进行一次验证,例如check_val_every_n_epoch and val_check_interval

完整代码

# coding:utf-8
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import lightning as L

# --------------------------------
# Step 1: 定义模型
# --------------------------------
class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):  # 测试,该方法与training_step相似
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def forward(self, x):
        # forward 定义了一次 预测/推理 行为
        embedding = self.encoder(x)
        return embedding
# --------------------------------
# Step 2: 加载数据+模型
# --------------------------------
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

# 训练集中的20%数据划为验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# 拆分,使用data.random_split方法
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

autoencoder = LitAutoEncoder()
# --------------------------------
# Step 3: 训练+验证+测试
# --------------------------------
# 训练+验证
trainer = L.Trainer(default_root_dir="some/path/")	# 这里自定义需要保存的路径
trainer.fit(autoencoder, train_loader, valid_loader)

# 测试
trainer.test(autoencoder, dataloaders=DataLoader(test_set))

到了这里,关于PyTorch Lightning教程二:验证、测试、checkpoint、早停策略的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • PyTorch Lightning快速学习教程一:快速训练一个基础模型

    粉丝量突破1200了!找到了喜欢的岗位,毕业上班刚好也有20天,为了督促自己终身学习的态度,继续开始坚持写写博客,沉淀并总结知识! 介绍:PyTorch Lightning是针对科研人员、机器学习开发者专门设计的,能够快速复用代码的一个工具,避免了因为每次都编写相似的代码而

    2024年02月16日
    浏览(55)
  • pytorch lightning和pytorch版本对应

    参见官方文档: https://lightning.ai/docs/pytorch/latest/versioning.html#compatibility-matrix 下图左一列( lightning.pytorch )安装命令: pip install lightning --use-feature=2020-resolver 下图左二列( pytorch_lightning )安装命令: pip install pytorch_lightning --use-feature=2020-resolver 加 --use-feature=2020-resolver 解决依赖

    2024年02月12日
    浏览(54)
  • pytorch lightning 入门

    翻译自官方文档 前置知识 :推荐pytorch 目标 :通过PL中7个关键步骤了解PL工作流程 PL是基于pytorch的高层API,自带丰富的工具为AI学者和工程师快速创建高性能模型,去除繁琐的重复流程同时保持灵活性。 使用组织好的pytorch代码,PL可以: 避免重复流程。比如gpu设置,device设

    2023年04月08日
    浏览(68)
  • Pytorch Lightning 训练更新次数

    假设一共1000个samples,batch size=4,因此一个epoch会有250 iterations,也就是会更新250次 当设置Trainer时 这个 max_steps 指的是最多更新的次数,这里也就是40次,而 accumulate_grad_batches 指的是每次更新前积累多少个batch,这里为2 因此,每次更新前实际上积累了2 * 4 = 8个samples的gradient

    2024年02月15日
    浏览(45)
  • Pytorch-lightning简介

    pytorch-lighting(简称pl),它其实就是一个轻量级的PyTorch库,用于高性能人工智能研究的轻量级PyTorch包装器。缩放你的模型,而不是样板。 研究代码(位于LightningModule中)。 工程代码(由Trainer处理)。 非必要的研究代码(日志记录等…这在Callbacks中进行) 暂时先上传这些内容

    2024年02月10日
    浏览(39)
  • 变分自编码器(VAE)PyTorch Lightning 实现

    ✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。 🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。 变分自编码器 (Variational Autoencoder,VAE)是一

    2024年02月21日
    浏览(51)
  • (5)深度学习学习笔记-多层感知机-pytorch lightning版

    pytorch lighting是导师推荐给我学习的一个轻量级的PyTorch库,代码干净简洁,使用pl更容易理解ML代码,对于初学者的我还是相对友好的。 pytorch lightning官网网址 https://lightning.ai/docs/pytorch/stable/levels/core_skills.html 代码如下: 代码如下:(可以直接把download改为true下载) 更多pl的方

    2024年02月12日
    浏览(44)
  • 版本匹配指南:PyTorch版本、Python版本和pytorch_lightning版本的对应关系

    版本匹配指南:PyTorch版本、Python版本和pytorch_lightning版本的对应关系 🌈 欢迎莅临 我的个人主页👈这里是我 静心耕耘 深度学习领域、 真诚分享 知识与智慧的小天地!🎇 🎓 博主简介: 我是 高斯小哥 ,一名来自985高校的普通本硕生,曾有幸在中科院顶刊发表过 一作论文

    2024年04月17日
    浏览(65)
  • 9、Flink四大基石之Checkpoint容错机制详解及示例(checkpoint配置、重启策略、手动恢复checkpoint和savepoint)

    一、Flink 专栏 Flink 专栏系统介绍某一知识点,并辅以具体的示例进行说明。 1、Flink 部署系列 本部分介绍Flink的部署、配置相关基础内容。 2、Flink基础系列 本部分介绍Flink 的基础部分,比如术语、架构、编程模型、编程指南、基本的datastream api用法、四大基石等内容。 3、

    2024年02月04日
    浏览(47)
  • 关于安装 PyTorch-Lightning 的一些问题(GPU版)

    官网地址: PyTorch PyTorch-Lightning 1、不能直接使用 pip install pytorch-lightning  ,否则如下图会直接卸载掉你的torch而安装cpu版本的torch。 2、在安装pytorch-lightning时一定注意自己的torch是pip安装还是conda安装,两者要保持一致,否则也会导致你的torch版本被替换。 正确安装方式: p

    2024年02月02日
    浏览(33)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包