说明
我参考了一个开源的人像语义分割项目mobile_phone_human_matting,这个项目提供了预训练模型,我想要将该模型固化,然后转换格式后在嵌入式端使用。
该项目保存模型的代码如下:
lastest_out_path = "{}/ckpt_lastest.pth".format(self.save_dir_model)
torch.save({
'epoch': epoch,
'state_dict': model.state_dict(),
}, lastest_out_path)
转换代码
上面代码保存了state_dict, 所以保存的文件中是不含模型结构的,固化时需要从代码构造网络结构。好在项目是完全开源,将原项目下的model目录拷贝过来就行。
另外不能忘记调用eval() 来固化参数。
完整的转换代码如下:
import torch
from model import segnet
ckptfile="./ckpt_lastest.pth"
savedfile="./human_seg.pt"
model = segnet.SegMattingNet()
device = torch.device('cpu')
ckpt = torch.load(ckptfile, map_location=device )
model.load_state_dict(ckpt['state_dict'])
model.eval() #这一步会将参数固化,不能省。否则会报AssertionError('batchnorm with training is not support. Please set model.eval() before export.')
x = torch.rand(1,3,256,256)
ts = torch.jit.trace(model, x)
ts.save(savedfile)
参考资料
mobile_phone_human_matting文章来源:https://www.toymoban.com/news/detail-687127.html
pytorch训练的.pth模型格式转换文章来源地址https://www.toymoban.com/news/detail-687127.html
到了这里,关于将pytorch的pth文件固化为pt文件的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!