Torch 模型 onnx 文件的导出和调用

这篇具有很好参考价值的文章主要介绍了Torch 模型 onnx 文件的导出和调用。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Open Neural Network Exchange (ONNX,开放神经网络交换) 格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移

Torch 所定义的模型为动态图,其前向传播是由类方法定义和实现的

但是 Python 代码的效率是比较底下的,试想把动态图转化为静态图,模型的推理速度应当有所提升

Torch 框架中,torch.onnx.export 可以将父类为 nn.Module 的模型导出到 onnx 文件中,最重要的有三个参数:

  • model:父类为 nn.Module 的模型
  • args:传入 model 的 forward 方法的变量列表,类型应为 tuple
  • f:onnx 文件名称的字符串
import torch
from torchvision.models import resnet50

file = 'resnet.onnx'
# 声明模型
resnet = resnet50(pretrained=False).eval()
image = torch.rand([1, 3, 224, 224])
# 导出为 onnx 文件
torch.onnx.export(resnet, (image,), file)

onnx 文件可被 Netron 打开,以查看模型结构

Torch 模型 onnx 文件的导出和调用

基本用法

要在 Python 中运行 onnx 模型,需要下载 onnxruntime

# 选其一即可
pip install onnxruntime        # CPU 版本
pip install onnxruntime-gpu    # GPU 版本

推理时需要借助其中的 InferenceSession,其中较为重要的实例方法有:

  • get_inputs():得到输入变量的列表 (变量属性:name、shape、type)
  • get_outputs():得到输入变量的列表 (变量属性:name、shape、type)
  • run(output_names, input_feed):输入变量为 numpy.ndarray (注意 dtype 应为 float32),使用模型推理并返回输出

可得出 onnx 模型的基本用法:

import onnxruntime as ort
import numpy as np

file = 'resnet.onnx'
# 找到 GPU / CPU
provider = ort.get_available_providers()[
    1 if ort.get_device() == 'GPU' else 0]
print('设备:', provider)
# 声明 onnx 模型
model = ort.InferenceSession(file, providers=[provider])

# 参考: ort.NodeArg
for node_list in model.get_inputs(), model.get_outputs():
    for node in node_list:
        attr = {'name': node.name,
                'shape': node.shape,
                'type': node.type}
        print(attr)
    print('-' * 60)

# 得到输入、输出结点的名称
input_node_name = model.get_inputs()[0].name
ouput_node_name = [node.name for node in model.get_outputs()]

image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model.run(output_names=ouput_node_name,
                input_feed={input_node_name: image}))

高级 API

为了简化使用步骤,使用类进行封装:

import numpy as np
import onnxruntime as ort
import torch.onnx


class timer:

    def __init__(self, repeat: int = 1, avg: bool = True):
        self.repeat = max(1, int(repeat) if isinstance(repeat, float) else repeat)
        self.avg = avg

    def __call__(self, func):
        import time

        def handler(*args, **kwargs):
            t0 = time.time()
            for i in range(self.repeat): func(*args, **kwargs)
            cost = (time.time() - t0) * 1e3
            return cost / self.repeat if self.avg else cost

        return handler


def onnx_simplify(src, new=None):
    ''' onnx 模型简化'''
    import onnxsim, onnx
    model, check = onnxsim.simplify(onnx.load(src))
    assert check, 'Failure to Simplify'
    onnx.save(model, new if new else src)


class OnnxModel(ort.InferenceSession):
    ''' onnx 推理模型
        provider: 优先使用 GPU'''
    device = property(fget=lambda self: self.get_providers()[0][:-17])

    def __init__(self, src):
        for pvd in ort.get_available_providers():
            try:
                super().__init__(str(src), providers=[pvd])
                break
            except:
                pass
        assert self.get_providers(), 'No available Execution Providers were found'
        # 参考: ort.NodeArg
        self.io_node = list(map(list, (self.get_inputs(), self.get_outputs())))
        self.io_name = [[n.name for n in nodes] for nodes in self.io_node]
        self.io_shape = [[n.shape for n in nodes] for nodes in self.io_node]

    def __call__(self, *inputs):
        input_feed = {name: x for name, x in zip(self.io_name[0], inputs)}
        return self.run(self.io_name[-1], input_feed)

    @classmethod
    def from_torch(cls, model, args, dst, test=False, **export_kwd):
        args = (args,) if isinstance(args, torch.Tensor) else args
        torch.onnx.export(model, args, dst, opset_version=11, **export_kwd)
        onnx_model = cls(dst)
        if test:
            Timer = timer(repeat=3)
            # 测试 Torch 的运行时间
            torch_output = model(*args).data.numpy()
            print(f'Torch: {Timer(model)(*args):.2f} ms')
            # data: tensor -> array
            args = tuple(map(lambda x: x.data.numpy(), args))
            # 测试 onnx 的运行时间
            onnx_output = onnx_model(*args)
            print(f'Onnx: {Timer(onnx_model)(*args):.2f} ms')
            # 计算 Torch 模型与 onnx 模型输出的绝对误差
            abs_error = np.abs(torch_output - onnx_output).mean()
            print(f'Mean Error: {abs_error:.2f}')
        return onnx_model

在 Torch 中,对于卷积神经网络 model 与图像 image,推理的代码为 "model(image)",而使用这个封装的类也是类似:

import numpy as np

file = 'resnet.onnx'
model = OnnxModule(file)

image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model(image))

from_torch 函数旨在导出 torch 模型的 onnx 文件,同时创建 onnx 推理模型

当其中的关键字参数 test = True 时,该函数会进行以下测试:

  • 得到 Torch 模型、onnx 模型的输出,并 print 推断耗时
  • 计算 Torch 模型与 onnx 模型输出的绝对误差的均值

对于 ResNet50 而言,Torch 模型的推断耗时为 172.67 ms,onnx 模型的推断耗时为 36.56 ms,onnx 模型的推断耗时仅为 Torch 模型的 21.17%文章来源地址https://www.toymoban.com/news/detail-432108.html

到了这里,关于Torch 模型 onnx 文件的导出和调用的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • yolov5 pt 模型 导出 onnx

    在训练好的yolov5 pt 模型 可以 通过 export.py 进行导出 onnx 导出流程 在 export.py 设置模型和数据源的yaml 在官方的文档中 说明了可以导出的具体的类型。 在 --include 添加导出的类型, 不同的 类型的 环境要求不一样,建议虚拟环境,比如onnx 和 openvino 的numpy 版本要求不一只,一

    2024年02月11日
    浏览(42)
  • 导出LLaMA等LLM模型为onnx

    通过onnx模型可以在支持onnx推理的推理引擎上进行推理,从而可以将LLM部署在更加广泛的平台上面。此外还可以具有避免pytorch依赖,获得更好的性能等优势。 这篇博客(大模型LLaMa及周边项目(二) - 知乎)进行了llama导出onnx的开创性的工作,但是依赖于侵入式修改transform

    2024年02月14日
    浏览(81)
  • 自然语言处理从入门到应用——静态词向量预训练模型:神经网络语言模型(Neural Network Language Model)

    分类目录:《自然语言处理从入门到应用》总目录 《自然语言处理从入门到应用——自然语言处理的语言模型(Language Model,LM)》中介绍了语言模型的基本概念,以及经典的基于离散符号表示的N元语言模型(N-gram Language Model)。从语言模型的角度来看,N元语言模型存在明显

    2024年02月09日
    浏览(54)
  • 导出LLaMA ChatGlm2等LLM模型为onnx

    通过onnx模型可以在支持onnx推理的推理引擎上进行推理,从而可以将LLM部署在更加广泛的平台上面。此外还可以具有避免pytorch依赖,获得更好的性能等优势。 这篇博客(大模型LLaMa及周边项目(二) - 知乎)进行了llama导出onnx的开创性的工作,但是依赖于侵入式修改transform

    2024年02月13日
    浏览(42)
  • OpenMMlab导出mobilenet-v2的onnx模型并推理

    使用mmpretrain导出mobilenet-v2的onnx模型: 安装有mmdeploy的话可以通过如下方法导出: 通过onnxruntime进行推理: 使用mmdeploy推理: 或者 这里通过trtexec转换onnx文件,LZ的版本是TensorRT-8.2.1.8。 使用mmdeploy推理: 或者

    2024年02月05日
    浏览(51)
  • Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

    目录 1--动态输入和静态输入 2--Pytorch API 3--完整代码演示 4--模型可视化 5--测试动态导出的Onnx模型         当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设

    2024年01月20日
    浏览(42)
  • 深度学习-Python调用ONNX模型

    目录 ONNX模型使用流程 获取ONNX模型方法 使用ONNX模型 手动编写ONNX模型 Python调用ONNX模型 常见错误 错误raise ValueError...: 错误:Load model model.onnx failed 错误:\\\'CUDAExecutionProvider\\\' is not in available provider 错误:ONNXRuntimeError 错误:\\\'CUDAExecutionProvider\\\' is not in available provider ONNX(Open Ne

    2024年02月05日
    浏览(49)
  • 如何在windows系统下将yolov5的pt模型导出为onnx模型

    最近在做本科毕业设计,要求是在树莓派上部署yolo算法来实现火灾检测,在网上查了很多资料,最后选择用yolov5s模型先试着在树莓派上部署,看下效果如何,由于从大佬那里拿到了yolov5火灾检测模型,但想要将它移植到树莓派上第一步要把pt模型转换成onnx模型,原因我想大

    2023年04月12日
    浏览(45)
  • C++调用yolov5 onnx模型的初步探索

    yolov5-dnn-cpp-python https://github.com/hpc203/yolov5-dnn-cpp-python 转onnx: 用opencv的dnn模块做yolov5目标检测的程序,包含两个步骤:(1).把pytorch的训练模型.pth文件转换到.onnx文件。(2).opencv的dnn模块读取.onnx文件做前向计算。 SiLU其实就是swish激活函数,而在onnx模型里是不直接支持swish算子的

    2024年02月12日
    浏览(37)
  • 【深度学习】【Opencv】【CPU】Python/C++调用onnx模型【基础】

    提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 OpenCV是一个基于BSD许可发行的跨平台计算机视觉和机器学习软件库(开源),可以运行在Linux、Windows、Android和Mac OS操作系统上。可以将pytorch中训练好的模型使用ONNX导出,再使用opencv中的dnn模块直接进行

    2024年02月04日
    浏览(46)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包