Pytorch 中的 checkpoint

这篇具有很好参考价值的文章主要介绍了Pytorch 中的 checkpoint。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

当我们在谈论 Pytorch checkpoint 时,我们可能在说两件不同的事情。

第一个是 General checkpoint,用它保存模型的参数、优化器的参数,以及 Epoch, loss 等任何你想要保存的东西。我们可以利用它进行断点续训,以及后续的模型推理。长时间训练大模型时,在代码中定期保存 checkpoint 也是一个好习惯。

第二个是 Gradient checkpoint,这是一种以时间换空间的技术:用更长的计算时间,换取显卡内存。

我们分别来看一下这两件完全不同的事情。


General checkpoint

When saving a general checkpoint, you must save more than just the model’s state_dict. It is important to also save the optimizer’s state_dict, as this contains buffers and parameters that are updated as the model trains. ——SAVING AND LOADING A GENERAL CHECKPOINT IN PYTORCH

在保存检查点的时候,如果你确定模型已经训练完毕,之后加载模型时只会用它做推理,那么你可以不保存优化器的参数。
但如果之后会进行断点续训,那么优化器参数是必须要保存的。像 Adam 这种优化器,更新参数时会用到历史梯度,必须将它们保存下来。

为了保险起见,一般的建议是:save both the model’s and optimizer’s state_dict

保存模型:

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, "./model.pt")

加载模型:

model = Net()
# optimizer should be the same as before
optimizer = optim.SGD(net.parameters())

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

Gradient checkpoint

前面提到,Gradient checkpoint 是一种节省显存的机制。首先的一个问题是,PyTorch 模型在训练过程中,显存存储的是什么?

PyTorch显存机制分析 这篇文章中说:

PyTorch 在进行深度学习训练的时候,有4大部分的显存开销,分别是模型参数 (parameters),模型参数的梯度 (gradients),优化器状态 (optimizer states) 以及中间激活值 (intermediate activations) 或者叫中间结果 (intermediate results)

前向传播中,中间激活值被保存下来;反向传播中,这些中间激活值被用来计算梯度,在计算完成后被销毁(释放)。

Gradient checkpoint 的思路是,在前向传播过程中不保存中间激活值;在反向传播要用到的时候再重新计算。这样当然节省了显存,但中间值被计算了两遍。


Pytorch 提供了两种使用梯度检查点的方式:torch.utils.checkpoint.checkpoint_sequential 以及 torch.utils.checkpoint.checkpoint

checkpoint_sequential

checkpoint_sequential 适用于前向传播逻辑简单的序列模型,即按照顺序执行列表中的 modules/functions。对于这种模型,可以把它分割成 N 个小块,对每一个小块做梯度检查。

import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

chunks = 3
model = nn.Sequential(...)
input_var = checkpoint_sequential(functions=model, segments=chunks, input=input_var, preserve_rng_state=True)

上面这段代码,把 sequential model 分成了三个小块。除了最后一个小块,其余两个小块(segment 1, segment 2)均以 torch.no_grad() 的方式进行,也就不需要保存中间激活值了。segment 1, segment 2 的输入会被保存,以便反向传播时重新计算它们的中间值。

checkpoint pytorch,pytorch,深度学习,人工智能

checkpoint

对于更复杂的模型结构,需要用 checkpoint

class CIFAR10Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_block_1 = nn.Sequential(*[
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25)
        ])
        self.cnn_block_2 = nn.Sequential(*[
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25)
        ])
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.head = nn.Sequential(*[
            nn.Linear(64 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        ])
    
    def forward(self, X):
        X = self.cnn_block_1(X)
        X = self.dropout_1(X)
        X = torch.utils.checkpoint.checkpoint(self.cnn_block_2, X)
        X = self.dropout_2(X)
        X = self.flatten(X)
        X = self.head(X)
        return X

我们对 cnn_block_2 设置了梯度检查——只需要给出 block of module,以及它的输入。有两点需要注意:

  1. Use of torch.utils.checkpoint.checkpoint causes simple model to diverge 这篇讨论里提到,为什么不对 cnn_block_1 设置梯度检查:
    cnn_block_1 的输入是原始输入,它的 requires_grad=False,因为我们只需要对模型权重求梯度,不需要对原始输入求梯度。而被梯度检查的模块的输出的 requires_grad 与输入的 requires_grad 保持一致。在 cnn_block_1 中,它的输出 requires_grad=False,导致模块的权重不会更新。因此这位作者建议,不要在紧跟着原始输入的模块上设置梯度检查

  2. 另一个需要注意的点是,我们可以对包含 Dropout layer 的模块设置梯度检查,但要注意 preserve_rng_state 这个参数。 Dropout layer 需要进行随机采样,随机数生成器的状态会随之改变。由于梯度检查需要进行两次前向传播,如果两次的随机数生成器的状态不一样,就会产生不同的结果。
    preserve_rng_state=True (默认),意味着程序会保存前一次随机数生成器的状态,在第二次前向传播时,保证 Dropout layer 的结果与第一次相同

官方文档对梯度检查中随机状态的解释:

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be advanced than they would without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout for example) have deterministic output as compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations.

设置 preserve_rng_state=True 会对性能造成一定程度的影响。


定点设置 General checkpoint 保存参数是一个好习惯;

Gradient checkpoint(梯度检查)是一项有用的技术,但需要在实战中练习。它可能会引起意想不到的 Bug,需要多加注意。


参考:文章来源地址https://www.toymoban.com/news/detail-577202.html

  • TORCH.UTILS.CHECKPOINT
  • Use of torch.utils.checkpoint.checkpoint causes simple model to diverge
  • PyTorch 显存机制分析

到了这里,关于Pytorch 中的 checkpoint的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 《人工智能专栏》必读150篇 | 专栏介绍 & 专栏目录 & Python与PyTorch | 机器与深度学习 | 目标检测 | YOLOv5及改进 | YOLOv8及改进 | 关键知识点 | 工具

    各位读者们好,本专栏最近刚推出,限于个人能力有限,不免会有诸多错误,敬请私信反馈给我,接受善意的提示,后期我会改正,谢谢,感谢。 第一步 :[ 购买点击跳转 ] 第二步 : 代码函数调用关系图(全网最详尽-重要) 因文档特殊,不能在博客正确显示,请移步以下链接

    2024年02月02日
    浏览(78)
  • AI写作革命:PyTorch如何助力人工智能走向深度创新

    身为专注于人工智能研究的学者,我十分热衷于分析\\\"AI写稿\\\"与\\\"PyTorch\\\"这两项领先技术。面对日益精进的人工智能科技,\\\"AI写作\\\"已不再是天方夜谭;而\\\"PyTorch\\\"如璀璨明珠般耀眼,作为深度学习领域的尖端工具,正有力地推进着人工智能化进程。于此篇文章中,我将详细解析\\\"

    2024年04月13日
    浏览(57)
  • 人工智能学习07--pytorch14--ResNet网络/BN/迁移学习详解+pytorch搭建

    亮点:网络结构特别深 (突变点是因为学习率除0.1?) 梯度消失 :假设每一层的误差梯度是一个小于1的数,则在反向传播过程中,每向前传播一层,都要乘以一个小于1的误差梯度。当网络越来越深的时候,相乘的这些小于1的系数越多,就越趋近于0,这样梯度就会越来越小

    2023年04月11日
    浏览(159)
  • 人工智能学习07--pytorch15(前接pytorch10)--目标检测:FPN结构详解

    backbone:骨干网络,例如cnn的一系列。(特征提取) (a)特征图像金字塔 检测不同尺寸目标。 首先将图片缩放到不同尺度,针对每个尺度图片都一次通过算法进行预测。 但是这样一来,生成多少个尺度就要预测多少次,训练效率很低。 (b)单一特征图 faster rcnn所采用的一种方式

    2023年04月12日
    浏览(74)
  • 【人工智能概论】 PyTorch中的topk、expand_as、eq方法

    对PyTorch中的tensor类型的数据都存在topk方法,其功能是按照要求取前k个最大值。 其最常用的场合就是求一个样本被网络认为的前k种可能的类别。 举例: torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) 其中: input: 是待处理的tensor数据; k: 指明要前k个数据及其index;

    2024年02月10日
    浏览(42)
  • PyTorch深度学习实战(2)——PyTorch基础

    PyTorch 是广泛应用于机器学习领域中的强大开源框架,因其易用性和高效性备受青睐。在本节中,将介绍使用 PyTorch 构建神经网络的基础知识。首先了解 PyTorch 的核心数据类型——张量对象。然后,我们将深入研究用于张量对象的各种操作。 PyTorch 提供了许多帮助构建神经网

    2024年02月09日
    浏览(41)
  • Pytorch深度学习 - 学习笔记

    dir() :打开,看见包含什么 help() :说明书 pytorch中读取数据主要涉及到两个类 Dataset 和 Dataloader 。 Dataset可以将可以使用的数据提取出来,并且可以对数据完成编号。即提供一种方式获取数据及其对应真实的label值。 Dataloader为网络提供不同的数据形式。 Dataset Dataset是一个抽

    2024年02月07日
    浏览(45)
  • 33- PyTorch实现分类和线性回归 (PyTorch系列) (深度学习)

    知识要点  pytorch 最常见的创建模型 的方式, 子类 读取数据: data = pd.read_csv (\\\'./dataset/credit-a.csv\\\', header=None) 数据转换为tensor: X = torch .from_numpy(X.values).type(torch.FloatTensor) 创建简单模型: 定义损失函数: loss_fn = nn.BCELoss () 定义优化器: opt = torch.optim.SGD (model.parameters(), lr=0.00001) 把梯度

    2024年02月06日
    浏览(50)
  • PyTorch深度学习实战(3)——使用PyTorch构建神经网络

    我们已经学习了如何从零开始构建神经网络,神经网络通常包括输入层、隐藏层、输出层、激活函数、损失函数和学习率等基本组件。在本节中,我们将学习如何在简单数据集上使用 PyTorch 构建神经网络,利用张量对象操作和梯度值计算更新网络权重。 1.1 使用 PyTorch 构建神

    2024年02月08日
    浏览(47)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包