1、模型保存和加载
主要有两种情况:一是仅保存参数,二是保存参数及模型结构。
保存参数:
torch.save(net.state_dict())
加载参数(加载参数前需要先实例化模型):
param = torch.load('param.pth')
net.load_state_dict(param)
保存模型结构和参数:
torch.save(net)
加载模型:
net = torch.load('model.pt')
2、解析模型权重文件
当加载某个模型文件后,如果需要查看模型中的算子和参数,可以将模型解析为字典,然后逐一打印。
以lent5为例,将lenet5模型保存为权重文件,然后重新加载权重文件并解析其中每一层的参数。
参考代码:
def pytorch_params(pth_file):
par_dict = torch.load(pth_file, map_location='cpu')
for name in par_dict:
parameter = par_dict[name]
print(name, parameter.numpy().shape)
以上代码是加载的权重文件,文件只有参数,没有模型结构,如果加载的是包含模型结构的权重文件,可以做如下修改:
def pytorch_params(pt_file):
net = torch.load(pt_file, map_location='cpu')
par_dict = net.state_dict()
for name in par_dict:
parameter = par_dict[name]
print(name, parameter.numpy().shape)
解析结果:
3、加载自定义参数
某些情况下可能需要对某个算子进行单独调试,如加载特定参数进行推理计算,用来确定输出结果符合预期。以Conv2d算子为例进行测试,首先设定卷积层输入为3,输出为3,卷积核为3*3,偏置bias为False。通过numpy随机一个3*3*3*3的矩阵作为自定义参数,将参数转换为Tensor以后,添加到dict中,然后通过load_state_dict将参数加载进网络。
参考脚本:文章来源:https://www.toymoban.com/news/detail-573069.html
文章来源地址https://www.toymoban.com/news/detail-573069.html
import torch
import torch.nn as nn
import numpy as np
net = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=1, bias=False)
param = np.random.random((3, 3, 3, 3))
param = param.astype(np.float32)
torch_param = {'weight': torch.Tensor(param)}
net.load_state_dict(torch_param)
net.eval()
data = np.random.random((1, 3, 16, 16))
data = data.astype(np.float32)
result = net(torch.Tensor(data))
print(result)
到了这里,关于pytorch保存、加载和解析模型权重的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!