解决方案:炼丹师养成计划 Pytorch如何进行断点续训——DFGAN断点续训实操

这篇具有很好参考价值的文章主要介绍了解决方案:炼丹师养成计划 Pytorch如何进行断点续训——DFGAN断点续训实操。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

我们在训练模型的时候经常会出现各种问题导致训练中断,比方说断电、系统中断、内存溢出、断连、硬件故障、地震火灾等之类的导致电脑系统关闭,从而将模型训练中断。

所以在实际运行当中,我们经常需要每100轮epoch或者每50轮epoch要保存训练好的参数,以防不测,这样下次可以直接加载该轮epoch的参数接着训练,就不用重头开始。下面我们来介绍Pytorch断点续训原理及DFGAN20版本和22版本断点续训实操

文末评论【人生苦短,我用Pytorch!】抽一位小伙伴送出《PyTorch教程:21个项目玩转PyTorch实战》书籍一本,包邮到家。

一、Pytorch断点续训

1.1、保存模型

pytorch保存模型等相关参数,需要利用torch.save(),torch.save()是PyTorch框架中用于保存Python对象到磁盘上的函数,一般为

torch.save(checkpoint, checkpoint_path)

其中checkpoint为保存模型的所有参数和缓存的键值对,checkpoint_path表示最终保存的模型,通常以.pth格式保存。

torch.save()函数会将obj序列化为字节流,并将字节流写入f指定的文件中。在读取数据时,可以使用torch.load()函数来将文件中的字节流反序列化成Python对象。使用这两个函数可以轻松地将PyTorch模型保存到磁盘上,并在需要的时候重新加载使用。

一般在实际操作中,我们写为:

torch.save(netG.state_dict(),'%s/netG_epoch_%d.pth' % (self.model_dir, epoch))

它接受两个参数:要保存的对象(即状态字典)和文件路径。在这里,状态字典是通过调用netG.state_dict()方法获得的,而文件路径是使用字符串格式化操作构建的。字符串'%s/netG_epoch_%d.pth' % (self.model_dir, epoch) 中,%s表示第一个字符串占位符将被替换为self.model_dir(即保存.pth文件的目录路径),%d表示第二个字符串占位符将被替换为epoch(即当前训练的轮数)。这样就可以在每一轮训练结束后将当前的网络模型参数保存到一个新的.pth文件中,文件名中包含轮数以便于后续的查看和比较。

1.2、读取模型

对应的,torch.load()函数是PyTorch框架中用于从磁盘上加载Python对象的函数。一般为:

 checkpoint = torch.load(log_dir)
 model.load_state_dict(checkpoint['model'])

torch.load()函数会从文件中读取字节流,并将其反序列化成Python对象。对于PyTorch模型,可以直接将其反序列化成模型对象。

一般实际操作中,我们常常写为:

model.load_state_dict(torch.load(path))

首先使用torch.load()函数从指定的路径中加载模型参数,得到一个字典对象,即state_dict。其中,字典的键是各个层次结构的名称,而键所对应的值则是该层次结构中各个参数的值。

然后,使用model.load_state_dict()函数将state_dict中的参数加载到已经定义好的模型中。这个函数的作用是将state_dict中每个键所对应的参数加载到模型中对应的键所指定的层次结构上。

需要注意的是,由于模型的结构和保存的参数的结构必须匹配,因此在加载参数之前,需要先定义好模型的结构,使其与保存的参数的结构相同。如果结构不匹配,会导致加载参数失败,甚至会引发错误。

二、DFGAN20版本

在DFGAN20版本当中,模型保存在DFGAN/code/models当中,其中netG_300.pth就是代表生成器第300轮的模型netD_300.pth也就是代表鉴别器第300轮的模型。
解决方案:炼丹师养成计划 Pytorch如何进行断点续训——DFGAN断点续训实操

我们可以将需要的模型的路径记下来,然后打开main.py文件,其中在270行左右的# # validation data #下面
解决方案:炼丹师养成计划 Pytorch如何进行断点续训——DFGAN断点续训实操
可以在下面这段代码的后面

netG = NetG(cfg.TRAIN.NF, 100, sentencelstm, wordlstm).to(device)
netD = NetD(cfg.TRAIN.NF).to(device)

增加两句:

netG.load_state_dict(torch.load('models/%s/netG_300.pth' % (cfg.CONFIG_NAME)))
netD.load_state_dict(torch.load('models/%s/netD_300.pth' % (cfg.CONFIG_NAME)))

这样,就成功读取了所选文件夹目录下的netG_300.pthnetD_300.pth,如果要在这个epoch下进行采样,只需要把code/cfg/bird.ymlB_VALIDATION改为 True,如果需要在这个epoch下进行断点续训则B_VALIDATION改为False就可以了。

三、DFGAN22版本

DFGAN22版本与DFGAN20版本代码结构有所不同,但是在断点续训的原理上是一样的。

DFGAN22版本在保存模型时并没有单独保存netG, netD, netC, optG, optD等模型,而且将他们的模型都保存为一个.pth文件,如名为state_epoch_940.pth代表的就是第940轮的所有断点文件。这些断点文件保存在code/saved_models/bird或cooc下,如:
解决方案:炼丹师养成计划 Pytorch如何进行断点续训——DFGAN断点续训实操
如果要进行断点续训,我们可以把这个文件路径记下来或者将文件挪到需要的位置,我一般将需要断点续训或者采样的模型放在pretrained文件夹下。

然后下一步,打开code/cfg/bird.yml文件,如果是coco数据集则打开coco.yml
解决方案:炼丹师养成计划 Pytorch如何进行断点续训——DFGAN断点续训实操
修改state_epoch为自己选定的第几轮模型(想读取state_epoch_940.pth,则state_epoch改为940,这样后面打印结果、保存模型就是从941开始了),然后修改checkpoint为相应模型的路径如:./saved_models/bird/pretrained/state_epoch_940.pth,最终如下所示:

state_epoch: 940
checkpoint: ./saved_models/bird/pretrained/state_epoch_940.pth

如果你想更深层次了解其原理,即DFGAN22 版是如何保存模型和读取模型的,可以打开code/lib/utils.py文件,在第140行附近写了保存模型的函数,与我们之前讲的原理是一样的,只不过他将netG, netD, netC, optG, optD等又做了一层,然后将其统一保存到state_epoch_中:

def save_models(netG, netD, netC, optG, optD, epoch, multi_gpus, save_path):
    if (multi_gpus==True) and (get_rank() != 0):
        None
    else:
        state = {'model': {'netG': netG.state_dict(), 'netD': netD.state_dict(), 'netC': netC.state_dict()}, \
                'optimizers': {'optimizer_G': optG.state_dict(), 'optimizer_D': optD.state_dict()},\
                'epoch': epoch}
        torch.save(state, '%s/state_epoch_%03d.pth' % (save_path, epoch))

在第90行到140行附近,也写了读取模型的方法,也就是读相应checkpoint的checkpoint['model']['netG'],看完你会发觉,原理很简单,代码也不算很难,遇到问题建议大家多多阅读源码。

def load_opt_weights(optimizer, weights):
    optimizer.load_state_dict(weights)
    return optimizer


def load_model_opt(netG, netD, netC, optim_G, optim_D, path, multi_gpus=False):
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus)
    netD = load_model_weights(netD, checkpoint['model']['netD'], multi_gpus)
    netC = load_model_weights(netC, checkpoint['model']['netC'], multi_gpus)
    optim_G = load_opt_weights(optim_G, checkpoint['optimizers']['optimizer_G'])
    optim_D = load_opt_weights(optim_D, checkpoint['optimizers']['optimizer_D'])
    return netG, netD, netC, optim_G, optim_D


def load_models(netG, netD, netC, path):
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    netG = load_model_weights(netG, checkpoint['model']['netG'])
    netD = load_model_weights(netD, checkpoint['model']['netD'])
    netC = load_model_weights(netC, checkpoint['model']['netC'])
    return netG, netD, netC


def load_netG(netG, path, multi_gpus, train):
    checkpoint = torch.load(path, map_location="cpu")
    netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus, train)
    return netG


def load_model_weights(model, weights, multi_gpus=False, train=True):
    if list(weights.keys())[0].find('module')==-1:
        pretrained_with_multi_gpu = False
    else:
        pretrained_with_multi_gpu = True
    if (multi_gpus==False) or (train==False):
        if pretrained_with_multi_gpu:
            state_dict = {
                key[7:]: value
                for key, value in weights.items()
            }
        else:
            state_dict = weights
    else:
        state_dict = weights
    model.load_state_dict(state_dict)
    return model

三、可能遇见的问题

问题1:模型中断后继续训练出错

在有些时候我们需要保存训练好的参数为path文件,以防不测,下次可以直接加载该轮epoch的参数接着训练,但是在重新加载时发现类似报错:

size mismatch for block0.affine0.linear1.linear2.weight: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for block0.affine0.linear1.linear2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).

问题原因:这是说明某个超参数出现了问题,可能你之前训练时候用的是64,现在准备在另外的机器上面续训的时候某个超参数设置的是32,导致了size mismatch,也有可能是你动过了模型的代码,导致现在代码和训练的模型匹配不上了。

解决方案:查看size mismatch的模型部分,将超参数改回来,并将代码和原本训练的代码保持一致。

问题2:模型中断后继续训练 效果直降

加载该轮epoch的参数接着训练,继续训练的过程是能够运行的,但是发现继续训练时效果大打折扣,完全没有中断前的最后几轮好。
问题原因:暂时未知,推测是续训时模型加载的问题,也有可能是保存和加载的方式问题
解决方案:统一保存和加载的方式,当我采用以下方式时,貌似避免了这个问题:
模型的保存:

torch.save(netG.state_dict(), 'models/%s/netG_%03d.pth' % (cfg.CONFIG_NAME, epoch))

模型的重新加载:

netD.load_state_dict(torch.load('models/%s/netD_300.pth' % (cfg.CONFIG_NAME), map_location='cuda:0'))

四、好书推荐(评论送书)

4.1、好书推荐

《PyTorch教程:21个项目玩转PyTorch实战》
解决方案:炼丹师养成计划 Pytorch如何进行断点续训——DFGAN断点续训实操
阅读这本书,可以通过经典项目入门 PyTorch,通过前沿项目提升 PyTorch,基于PyTorch玩转深度学习,本书适合人工智能、机器学习、深度学习方面的人员阅读,也适合其他 IT 方面从业者,另外,还可以作为相关专业的教材。

京东自营购买链接:https://item.jd.com/13522327.html

评论区评论【人生苦短,我用Pytorch!】抽一位小伙伴送出《PyTorch教程:21个项目玩转PyTorch实战》书籍一本,包邮到家。文章来源地址https://www.toymoban.com/news/detail-408591.html

目前买还有优惠:北京大学出版社4月“423世界读书日”促销活动
当当活动日期:4.6-4.11,4.18-4.23
京东活动日期: 4.6 一天, 4.17-4.23
活动期间满100减50或者半价5折销售
希望大家关注参与423读书日北大社促销活动
京东自营购买链接:https://item.jd.com/13522327.html

4.2、内容简介

PyTorch 是基于 Torch 库的开源机器学习库,它主要由 Meta(原 Facebook)的人工智能研究实验室开发,在自然语言处理和计算机视觉领域都具有广泛的应用。本书介绍了简单且经典的入门项目,方便快速上手,如 MNIST数字识别,读者在完成项目的过程中可以了解数据集、模型和训练等基础概念。本书还介绍了一些实用且经典的模型,如 R-CNN 模型,通过这个模型的学习,读者可以对目标检测任务有一个基本的认识,对于基本的网络结构原理有一定的了解。另外,本书对于当前比较热门的生成对抗网络和强化学习也有一定的介绍,方便读者拓宽视野,掌握前沿方向。

4.3、作者简介

王飞,2019年翻译了PyTorch官方文档,读研期间研究方向为自然语言处理,主要是中文分词、文本分类和数据挖掘。目前在教育行业工作,探索人工智能技术在教育中的应用。
何健伟,曾任香港大学助理研究员,研究方向为自然语言处理,目前从事大规模推荐算法架构研究工作。
林宏彬,硕士期间研究方向为自然语言处理,现任阿里巴巴算法工程师,目前从事广告推荐领域的算法研究工作。
史周安,软件工程硕士,人工智能技术爱好者、实践者与探索者。目前从事弱监督学习、迁移学习与医学图像相关工作。

💡 最后

我们已经建立了🏤T2I研学社群,如果你还有其他疑问或者对🎓文本生成图像很感兴趣,可以私信我加入社群

📝 加入社群 抱团学习:中杯可乐多加冰-深度学习T2I研习群

🔥 限时免费订阅:文本生成图像T2I专栏

🎉 支持我:点赞👍+收藏⭐️+留言📝

评论区评论【人生苦短,我用Pytorch!】抽一位小伙伴送出《PyTorch教程:21个项目玩转PyTorch实战》书籍一本,包邮到家。

到了这里,关于解决方案:炼丹师养成计划 Pytorch如何进行断点续训——DFGAN断点续训实操的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • pytorch下载慢甚至下载失败怎么办?看看我的解决方案

    直接按官网命令下载torch文件太慢,有时候还可能下一半直接中断导致下载失败。。 我们到https://download.pytorch.org/whl/torch_stable.html这个网站里:  cu+序号后面表示cuda版本,即GPU版本(cpu+序号表示cpu),如cu117表示cuda 11.7; cp+序号表示python版本,如cp310表示python 3.10; 我们按自己

    2024年02月16日
    浏览(60)
  • Anaconda prompt中使用conda下载pytorch,一直卡在solving environment解决方案

    关闭梯子 清空镜像源: conda config --remove-key channels 在pytorch官网找到对应的版本与命令:PyTorch (我的电脑CUDA版本为12.1.103,故下载12.1的版本) 复制命令、下载 然后就是漫长的等待.. 参考 Conda - Downloaded bytes did not match Content-Length 问题解决方案_condaerror: downloaded bytes did not ma

    2024年02月07日
    浏览(48)
  • Server - PyTorch BFloat16 “TypeError: Got unsupported ScalarType BFloat16“ 解决方案

    欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/132665807 BFloat16 类型是 16 位的浮点数格式,可以用来加速深度学习的计算和存储。BFloat16 类型的特点是保留 32 位浮点数(float32)的 8 位指数部分,但是只有 8 位的有效数字部分(而不是 fl

    2024年02月09日
    浏览(41)
  • 实验4 图的应用问题 给定n个村庄之间的交通图,现计划在n个村庄中选定一个村庄建造一所医院,请设计方案解决问题

    给定n个村庄之间的交通图,现计划在n个村庄中选定一个村庄建造一所医院,请设计方案解决如下问题: (1)求出该医院应建在哪个村庄,才能使距离最远的村庄到医院的路程最短; (2)求出该医院应建在哪个村庄,能使其它所有村庄到医院的路径总和最短。 对于如上所示

    2023年04月25日
    浏览(43)
  • 全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据!

    目标 :基于 pytorch 、 transformers 做中文领域的nlp开箱即用的训练框架,提供全套的训练、微调模型(包括大模型、文本转向量、文本生成、多模态等模型)的解决方案; 数据 : 从开源社区,整理了海量的训练数据,帮助用户可以快速上手; 同时也开放训练数据模版,可以快

    2024年02月11日
    浏览(38)
  • 【CUDA driver initialization failed, you might not have a CUDA gpu】pytorch 解决方案

    在coding的时候我们经常在指定device的时候用这么一句代码: 但是有时候我们会发现device确实是放在了cpu上面,所以为了明确出错的原因,我们在shell里先import了torch,再执行 torch.cuda.is_available() ,发现在返回 False 结果之前给出了错误原因,其中部分内容就是我们在标题中写的

    2024年02月12日
    浏览(46)
  • 最佳解决方案:如何在网络爬虫中解决验证码

    Captcha(全自动区分计算机和人类的公开图灵测试)是广泛应用的安全措施,用于区分合法的人类用户和自动化机器人。它通过呈现复杂的挑战,包括视觉上扭曲的文本、复杂的图像或复杂的拼图等方式,要求用户成功解决这些挑战以验证其真实性。然而,在进行网络爬虫时,

    2024年01月23日
    浏览(37)
  • 如何快速制作解决方案PPT

    在开始制作解决方案PPT之前,需要对客户的需求进行深入了解和分析。这包括客户需要解决的问题、目标、预算和时间限制等。 需求分析 客户需要解决的问题 客户的目标 预算限制 时间限制 基于客户的需求,确定最优的解决方案。解决方案可以包括产品、服务或业务流程等

    2024年02月12日
    浏览(37)
  • pytorch版本不匹配导致的THC.h: No such file or directory 、THCCudaMalloc not defined等问题解决方案

    在论文复现安装maskrcnn-benchmark依赖项的过程中,遇见了pytorch版本不匹配导致的无法安装的问题,现存的大多数内容都建议安装低版本的pytorch以解决问题,但也不能总是这么干,不然自己这兼容性也太差了,顺便也吐槽一下pytorch的兼容性问题。在此总结一下,方便遇到相似问

    2023年04月18日
    浏览(56)
  • 基于Ascend910+PyTorch1.11.0+CANN6.3.RC2的YoloV5训练推理一体化解决方案

    昇腾Pytorch镜像:https://ascendhub.huawei.com/#/detail/ascend-pytorch 代码仓:git clone https://gitee.com/ascend/modelzoo-GPL.git coco测试验证集:wget https://bj-aicc.obs.cn-north-309.mtgascendic.cn/dataset/coco2017/coco.zip coco训练集(放images下):wget https://bj-aicc.obs.cn-north-309.mtgascendic.cn/dataset/coco2017/train2017.zip 部分

    2024年02月07日
    浏览(52)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包