PyTorch Lightning教程八:用模型预测,部署

这篇具有很好参考价值的文章主要介绍了PyTorch Lightning教程八:用模型预测,部署。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

关于Checkpoints的内容在教程2里已经有了详细的说明,在本节,需要用它来利用模型进行预测

加载checkpoint并预测

使用模型进行预测的最简单方法是使用LightningModule中的load_from_checkpoint加载权重。

model = LitModel.load_from_checkpoint("best_model.ckpt")
model.eval()
x = torch.randn(1, 64)

with torch.no_grad():
    y_hat = model(x)

predict_step方法

加载检查点并进行预测仍然会在预测阶段的epoch留下许多boilerplate,LightningModule中的预测步骤删除了这个boilerplate 。

class MyModel(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

并将任何dataloader传递给Lightning Trainer

data_loader = DataLoader(...)
model = MyModel()
trainer = Trainer()
predictions = trainer.predict(model, data_loader)

预测逻辑

当需要向数据添加复杂的预处理或后处理时,使用predict_step方法。例如,这里我们使用Monte Carlo Dropout 进行预测

class LitMCdropoutModel(pl.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # enable Monte Carlo Dropout
        self.dropout.train()

        # take average of `self.mc_iteration` iterations
        pred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]
        pred = torch.vstack(pred).mean(dim=0)
        return pred

启用分布式推理

通过使用Lightning中的predict_step,可以使用BasePredictionWriter进行分布式推理。

import torch
from lightning.pytorch.callbacks import BasePredictionWriter


class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # 在'output_dir'中创建N (num进程)个文件,每个文件都包含对其各自rank的预测
        torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))

        # 可以保存'batch_indices',以便从预测数据中获取有关数据索引的信息
        torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))


# 可以设置writer_interval="batch"
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)

也可以加载保存的checkpoint,把它当作一个普通的torch.nn.Module来使用。可以提取所有的torch.nn.Module,并在训练后使用LightningModule保存的checkpoint加载权重。建议从LightningModule的init和forward方法中复制明确的实现。文章来源地址https://www.toymoban.com/news/detail-650149.html

class Encoder(nn.Module):
    ...


class Decoder(nn.Module):
    ...


class AutoEncoderProd(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.encoder(x)


class AutoEncoderSystem(LightningModule):
    def __init__(self):
        super().__init__()
        self.auto_encoder = AutoEncoderProd()

    def forward(self, x):
        return self.auto_encoder.encoder(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.auto_encoder.encoder(x)
        y_hat = self.auto_encoder.decoder(y_hat)
        loss = ...
        return loss


# 训练
trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
model = AutoEncoderSystem()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")


# 创建PyTorch模型并加载checkpoint权重
model = AutoEncoderProd()
checkpoint = torch.load("best_model.ckpt")
hyper_parameters = checkpoint["hyper_parameters"]

# 恢复超参数
model = AutoEncoderProd(**hyper_parameters)

model_weights = checkpoint["state_dict"]

# 通过 dropping `auto_encoder.` 更新key值
for key in list(model_weights):
    model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)

model.load_state_dict(model_weights)
model.eval()
x = torch.randn(1, 64)

with torch.no_grad():
    y_hat = model(x)

到了这里,关于PyTorch Lightning教程八:用模型预测,部署的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • PyTorch Lightning教程五:Debug调试

    如果遇到了这样一个问题,当一次训练模型花了好几天,结果突然在验证或测试的时候崩掉了,这个时候其实是很奔溃的,主要还是由于没有提前知道哪些时候会出现什么问题,本节会引入Lightning的Debug方案 1.fast_dev_run参数 Trainer中的fast_dev_run参数通过你的训练器运行5批训练

    2024年02月14日
    浏览(45)
  • PyTorch Lightning教程七:可视化

    本节指导如何利用Lightning进行可视化和监控模型 为何需要跟踪参数 在模型开发中,我们跟踪感兴趣的值,例如validation_loss,以可视化模型的学习过程。模型开发就像驾驶一辆没有窗户的汽车,图表和日志提供了窗口,让我们知道该把车开到哪里。有了Lightning,几乎可以可视

    2024年02月14日
    浏览(39)
  • PyTorch Lightning教程四:超参数的使用

    如果需要和命令行接口进行交互,可以使用Python中的argparse包,快捷方便,对于Lightning而言,可以利用它,在命令行窗口中,直接配置超参数等操作,但也可以使用LightningCLI的方法,更加轻便简单。 ArgumentParser ArgumentParser是Python的内置特性,进而构建CLI程序,我们可以使用它

    2024年02月15日
    浏览(35)
  • PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

    介绍:上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型(仅使用训练集),为了保证模型可以泛化到未见过的数据上,数据集通常被分为训练和测试两个集合,测试集与训练集相互独立,用以测试模型的泛化能力。本期通过增加验证和测试集来达到该目的,同时,还

    2024年02月16日
    浏览(35)
  • pytorch lightning和pytorch版本对应

    参见官方文档: https://lightning.ai/docs/pytorch/latest/versioning.html#compatibility-matrix 下图左一列( lightning.pytorch )安装命令: pip install lightning --use-feature=2020-resolver 下图左二列( pytorch_lightning )安装命令: pip install pytorch_lightning --use-feature=2020-resolver 加 --use-feature=2020-resolver 解决依赖

    2024年02月12日
    浏览(54)
  • 人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程,本文将介绍如何使用PyTorch搭建ELMo模型,包括ELMo模型的原理、数据样例、模型训练、损失值和准确率的打印以及预测。文章将提供完整的代码实现。 ELMo模型简介 数据

    2024年02月07日
    浏览(67)
  • 人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型

    大家好,我是微学AI,今天给大家介绍一下人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型,在本文中,我们将学习如何使用PyTorch搭建卷积神经网络ResNet模型,并在生成的假数据上进行训练和测试。本文将涵盖这些内容:ResNet模型简介、ResNet模型结构、生成假

    2024年02月06日
    浏览(78)
  • 人工智能(pytorch)搭建模型12-pytorch搭建BiGRU模型,利用正态分布数据训练该模型

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型12-pytorch搭建BiGRU模型,利用正态分布数据训练该模型。本文将介绍一种基于PyTorch的BiGRU模型应用项目。我们将首先解释BiGRU模型的原理,然后使用PyTorch搭建模型,并提供模型代码和数据样例。接下来,我们将

    2024年02月09日
    浏览(68)
  • pytorch lightning 入门

    翻译自官方文档 前置知识 :推荐pytorch 目标 :通过PL中7个关键步骤了解PL工作流程 PL是基于pytorch的高层API,自带丰富的工具为AI学者和工程师快速创建高性能模型,去除繁琐的重复流程同时保持灵活性。 使用组织好的pytorch代码,PL可以: 避免重复流程。比如gpu设置,device设

    2023年04月08日
    浏览(68)
  • 人工智能(pytorch)搭建模型13-pytorch搭建RBM(受限玻尔兹曼机)模型,调通模型的训练与测试

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型13-pytorch搭建RBM(受限玻尔兹曼机)模型,调通模型的训练与测试。RBM(受限玻尔兹曼机)可以在没有人工标注的情况下对数据进行学习。其原理类似于我们人类学习的过程,即通过观察、感知和记忆不同事物的特点

    2024年02月10日
    浏览(77)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包