PyTorch-Lightning:trining_step的自动优化

这篇具有很好参考价值的文章主要介绍了PyTorch-Lightning:trining_step的自动优化。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

PyTorch-Lightning:trining_step的自动优化

使用PyTorch-Lightning时,在trining_step定义损失,在没有定义损失,没有任何返回的情况下没有报错,在定义一个包含loss的多个元素字典返回时,也可以正常训练,那么到底lightning是怎么完成训练过程的。

总结:

在自动优化中,training_step必须返回一个tensor或者dict或者None(跳过),对于简单的使用,在training_step可以return一个tensor会作为Loss回传,也可以return一个字典,其中必须包括key"loss",字典中的"loss"会提取出来作为Loss回传,具体过程主要包含在lightning\pytorch\loop\sautomatic.py中的_ AutomaticOptimization()类。

PyTorch-Lightning:trining_step的自动优化,pytorch,人工智能,python

class _ AutomaticOptimization()

实现自动优化(前向,梯度清零,后向,optimizer step)

在training_epoch_loop中会调用这个类的run函数。

def run

首先通过 _make_closure得到一个closure,详见def _make_closure,最后返回一个字典,如果我们在training_step只return了一个loss tensor则字典只有一个’loss’键值对,如果return了一个字典,则包含其他键值对。

可以看到调用了_ optimizer_step,_ optimizer_step经过层层调用,最后会调用torch默认的optimizer.zero_grad,而我们通过 make_closure得到的closure作为参数传入,具体而言是调用了closure类的_ call __()方法。

def run(self, optimizer: Optimizer, batch_idx: int, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
        closure = self._make_closure(kwargs, optimizer, batch_idx)

        if (
            # when the strategy handles accumulation, we want to always call the optimizer step
            not self.trainer.strategy.handles_gradient_accumulation and self.trainer.fit_loop._should_accumulate()
        ):
            # For gradient accumulation

            # -------------------
            # calculate loss (train step + train step end)
            # -------------------
            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
            with _block_parallel_sync_behavior(self.trainer.strategy, block=True):
                closure()

        # ------------------------------
        # BACKWARD PASS
        # ------------------------------
        # gradient update with accumulated gradients
        else:
            self._optimizer_step(batch_idx, closure)

        result = closure.consume_result()
        if result.loss is None:
            return {}
        return result.asdict()

def _make_closure

创建一个closure对象,来捕捉给定的参数并且运行’training_step’和可选的其他如backword和zero_grad函数

比较重要的是step_fn,在这里调用了_training_step,得到的是一个存储我们在定义模型时重写的training step的输出所构成ClosureResult数据类。具体见def _training_step

def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer, batch_idx: int) -> Closure:

        step_fn = self._make_step_fn(kwargs)
        backward_fn = self._make_backward_fn(optimizer)
        zero_grad_fn = self._make_zero_grad_fn(batch_idx, optimizer)
        return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)

def _training_step

通过hook函数实现真正的训练step,返回一个存储training step输出的ClosureResult数据类。

将我们在定义模型时定义的lightning.pytorch.core.LightningModule.training_step的输出作为参数传入存储容器class ClosureResult的from_training_step_output方法,见class Closure

class ClosureResult():

一个数据类,包含closure_loss,loss,extra

    closure_loss: Optional[Tensor]
    loss: Optional[Tensor] = field(init=False, default=None)
    extra: Dict[str, Any] = field(default_factory=dict)
def from_training_step_output

一个类方法,如果我们在training_step定义的返回是一个字典,则我们会将key值为"loss"的value赋值给closure_loss,而其余的键值对赋值给extra字典,如果返回的既不是包含"loss"的字典也不是tensor,则会报错。当我们在training_step不设定返回,则自然为None,但是不会报错。

class Closure

闭包是将外部作用域中的变量绑定到对这些变量进行计算的函数变量,而不将它们明确地作为输入。这样做的好处是可以将闭包传递给对象,之后可以像函数一样调用它,但不需要传入任何参数。

在lightning,用于自动优化的Closure类将training_step和backward, zero_grad三个基本的闭包结合在一起。

这个Closure得到training循环中的结果之后传入torch.optim.Optimizer.step。

参数:

  • step_fn: 这里是一个存储了training step输出的ClosureResult数据类,见def _training_step
  • backward_fn: 梯度回传函数
  • zero_grad_fn: 梯度清零函数

按照顺序,会先检查得到loss,之后调用梯度清零函数,最后调用梯度回传函数文章来源地址https://www.toymoban.com/news/detail-852831.html

class Closure(AbstractClosure[ClosureResult]):

    warning_cache = WarningCache()

    def __init__(
        self,
        step_fn: Callable[[], ClosureResult],
        backward_fn: Optional[Callable[[Tensor], None]] = None,
        zero_grad_fn: Optional[Callable[[], None]] = None,
    ):
        super().__init__()
        self._step_fn = step_fn
        self._backward_fn = backward_fn
        self._zero_grad_fn = zero_grad_fn

    @override
    @torch.enable_grad()
    def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
        step_output = self._step_fn()

        if step_output.closure_loss is None:
            self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

        if self._zero_grad_fn is not None:
            self._zero_grad_fn()

        if self._backward_fn is not None and step_output.closure_loss is not None:
            self._backward_fn(step_output.closure_loss)

        return step_output

    @override
    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
        self._result = self.closure(*args, **kwargs)
        return self._result.loss

到了这里,关于PyTorch-Lightning:trining_step的自动优化的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Pytorch Lightning 训练更新次数

    假设一共1000个samples,batch size=4,因此一个epoch会有250 iterations,也就是会更新250次 当设置Trainer时 这个 max_steps 指的是最多更新的次数,这里也就是40次,而 accumulate_grad_batches 指的是每次更新前积累多少个batch,这里为2 因此,每次更新前实际上积累了2 * 4 = 8个samples的gradient

    2024年02月15日
    浏览(45)
  • PyTorch Lightning教程五:Debug调试

    如果遇到了这样一个问题,当一次训练模型花了好几天,结果突然在验证或测试的时候崩掉了,这个时候其实是很奔溃的,主要还是由于没有提前知道哪些时候会出现什么问题,本节会引入Lightning的Debug方案 1.fast_dev_run参数 Trainer中的fast_dev_run参数通过你的训练器运行5批训练

    2024年02月14日
    浏览(45)
  • PyTorch Lightning教程七:可视化

    本节指导如何利用Lightning进行可视化和监控模型 为何需要跟踪参数 在模型开发中,我们跟踪感兴趣的值,例如validation_loss,以可视化模型的学习过程。模型开发就像驾驶一辆没有窗户的汽车,图表和日志提供了窗口,让我们知道该把车开到哪里。有了Lightning,几乎可以可视

    2024年02月14日
    浏览(39)
  • PyTorch Lightning教程八:用模型预测,部署

    关于Checkpoints的内容在教程2里已经有了详细的说明,在本节,需要用它来利用模型进行预测 加载checkpoint并预测 使用模型进行预测的最简单方法是使用LightningModule中的load_from_checkpoint加载权重。 predict_step方法 加载检查点并进行预测仍然会在预测阶段的epoch留下许多boilerplate,

    2024年02月12日
    浏览(39)
  • PyTorch Lightning教程四:超参数的使用

    如果需要和命令行接口进行交互,可以使用Python中的argparse包,快捷方便,对于Lightning而言,可以利用它,在命令行窗口中,直接配置超参数等操作,但也可以使用LightningCLI的方法,更加轻便简单。 ArgumentParser ArgumentParser是Python的内置特性,进而构建CLI程序,我们可以使用它

    2024年02月15日
    浏览(35)
  • PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

    介绍:上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型(仅使用训练集),为了保证模型可以泛化到未见过的数据上,数据集通常被分为训练和测试两个集合,测试集与训练集相互独立,用以测试模型的泛化能力。本期通过增加验证和测试集来达到该目的,同时,还

    2024年02月16日
    浏览(34)
  • 变分自编码器(VAE)PyTorch Lightning 实现

    ✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。 🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。 变分自编码器 (Variational Autoencoder,VAE)是一

    2024年02月21日
    浏览(51)
  • (5)深度学习学习笔记-多层感知机-pytorch lightning版

    pytorch lighting是导师推荐给我学习的一个轻量级的PyTorch库,代码干净简洁,使用pl更容易理解ML代码,对于初学者的我还是相对友好的。 pytorch lightning官网网址 https://lightning.ai/docs/pytorch/stable/levels/core_skills.html 代码如下: 代码如下:(可以直接把download改为true下载) 更多pl的方

    2024年02月12日
    浏览(44)
  • 版本匹配指南:PyTorch版本、Python版本和pytorch_lightning版本的对应关系

    版本匹配指南:PyTorch版本、Python版本和pytorch_lightning版本的对应关系 🌈 欢迎莅临 我的个人主页👈这里是我 静心耕耘 深度学习领域、 真诚分享 知识与智慧的小天地!🎇 🎓 博主简介: 我是 高斯小哥 ,一名来自985高校的普通本硕生,曾有幸在中科院顶刊发表过 一作论文

    2024年04月17日
    浏览(65)
  • PyTorch Lightning快速学习教程一:快速训练一个基础模型

    粉丝量突破1200了!找到了喜欢的岗位,毕业上班刚好也有20天,为了督促自己终身学习的态度,继续开始坚持写写博客,沉淀并总结知识! 介绍:PyTorch Lightning是针对科研人员、机器学习开发者专门设计的,能够快速复用代码的一个工具,避免了因为每次都编写相似的代码而

    2024年02月16日
    浏览(55)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包