Open Neural Network Exchange (ONNX,开放神经网络交换) 格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移
Torch 所定义的模型为动态图,其前向传播是由类方法定义和实现的
但是 Python 代码的效率是比较底下的,试想把动态图转化为静态图,模型的推理速度应当有所提升
Torch 框架中,torch.onnx.export 可以将父类为 nn.Module 的模型导出到 onnx 文件中,最重要的有三个参数:
- model:父类为 nn.Module 的模型
- args:传入 model 的 forward 方法的变量列表,类型应为 tuple
- f:onnx 文件名称的字符串
import torch
from torchvision.models import resnet50
file = 'resnet.onnx'
# 声明模型
resnet = resnet50(pretrained=False).eval()
image = torch.rand([1, 3, 224, 224])
# 导出为 onnx 文件
torch.onnx.export(resnet, (image,), file)
onnx 文件可被 Netron 打开,以查看模型结构
基本用法
要在 Python 中运行 onnx 模型,需要下载 onnxruntime
# 选其一即可
pip install onnxruntime # CPU 版本
pip install onnxruntime-gpu # GPU 版本
推理时需要借助其中的 InferenceSession,其中较为重要的实例方法有:
- get_inputs():得到输入变量的列表 (变量属性:name、shape、type)
- get_outputs():得到输入变量的列表 (变量属性:name、shape、type)
- run(output_names, input_feed):输入变量为 numpy.ndarray (注意 dtype 应为 float32),使用模型推理并返回输出
可得出 onnx 模型的基本用法:
import onnxruntime as ort
import numpy as np
file = 'resnet.onnx'
# 找到 GPU / CPU
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
print('设备:', provider)
# 声明 onnx 模型
model = ort.InferenceSession(file, providers=[provider])
# 参考: ort.NodeArg
for node_list in model.get_inputs(), model.get_outputs():
for node in node_list:
attr = {'name': node.name,
'shape': node.shape,
'type': node.type}
print(attr)
print('-' * 60)
# 得到输入、输出结点的名称
input_node_name = model.get_inputs()[0].name
ouput_node_name = [node.name for node in model.get_outputs()]
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model.run(output_names=ouput_node_name,
input_feed={input_node_name: image}))
高级 API
为了简化使用步骤,使用类进行封装:
import numpy as np
import onnxruntime as ort
import torch.onnx
class timer:
def __init__(self, repeat: int = 1, avg: bool = True):
self.repeat = max(1, int(repeat) if isinstance(repeat, float) else repeat)
self.avg = avg
def __call__(self, func):
import time
def handler(*args, **kwargs):
t0 = time.time()
for i in range(self.repeat): func(*args, **kwargs)
cost = (time.time() - t0) * 1e3
return cost / self.repeat if self.avg else cost
return handler
def onnx_simplify(src, new=None):
''' onnx 模型简化'''
import onnxsim, onnx
model, check = onnxsim.simplify(onnx.load(src))
assert check, 'Failure to Simplify'
onnx.save(model, new if new else src)
class OnnxModel(ort.InferenceSession):
''' onnx 推理模型
provider: 优先使用 GPU'''
device = property(fget=lambda self: self.get_providers()[0][:-17])
def __init__(self, src):
for pvd in ort.get_available_providers():
try:
super().__init__(str(src), providers=[pvd])
break
except:
pass
assert self.get_providers(), 'No available Execution Providers were found'
# 参考: ort.NodeArg
self.io_node = list(map(list, (self.get_inputs(), self.get_outputs())))
self.io_name = [[n.name for n in nodes] for nodes in self.io_node]
self.io_shape = [[n.shape for n in nodes] for nodes in self.io_node]
def __call__(self, *inputs):
input_feed = {name: x for name, x in zip(self.io_name[0], inputs)}
return self.run(self.io_name[-1], input_feed)
@classmethod
def from_torch(cls, model, args, dst, test=False, **export_kwd):
args = (args,) if isinstance(args, torch.Tensor) else args
torch.onnx.export(model, args, dst, opset_version=11, **export_kwd)
onnx_model = cls(dst)
if test:
Timer = timer(repeat=3)
# 测试 Torch 的运行时间
torch_output = model(*args).data.numpy()
print(f'Torch: {Timer(model)(*args):.2f} ms')
# data: tensor -> array
args = tuple(map(lambda x: x.data.numpy(), args))
# 测试 onnx 的运行时间
onnx_output = onnx_model(*args)
print(f'Onnx: {Timer(onnx_model)(*args):.2f} ms')
# 计算 Torch 模型与 onnx 模型输出的绝对误差
abs_error = np.abs(torch_output - onnx_output).mean()
print(f'Mean Error: {abs_error:.2f}')
return onnx_model
在 Torch 中,对于卷积神经网络 model 与图像 image,推理的代码为 "model(image)",而使用这个封装的类也是类似:
import numpy as np
file = 'resnet.onnx'
model = OnnxModule(file)
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model(image))
from_torch 函数旨在导出 torch 模型的 onnx 文件,同时创建 onnx 推理模型
当其中的关键字参数 test = True 时,该函数会进行以下测试:文章来源:https://www.toymoban.com/news/detail-432108.html
- 得到 Torch 模型、onnx 模型的输出,并 print 推断耗时
- 计算 Torch 模型与 onnx 模型输出的绝对误差的均值
对于 ResNet50 而言,Torch 模型的推断耗时为 172.67 ms,onnx 模型的推断耗时为 36.56 ms,onnx 模型的推断耗时仅为 Torch 模型的 21.17%文章来源地址https://www.toymoban.com/news/detail-432108.html
到了这里,关于Torch 模型 onnx 文件的导出和调用的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!