LibTorch实战四:模型序列化

这篇具有很好参考价值的文章主要介绍了LibTorch实战四:模型序列化。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

目录

在C++环境中加载一个TORCHSCRIP

Step1: 将pytorch模型转为torch scrip类型的模型

1.1、基于Tracing的方法来转换为Torch Script

1.2、基于Annotating (Script)的方法来转换为Torch Script

Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件

Step3: 在libtorch中加载ScriptModule模型

总结


在C++环境中加载一个TORCHSCRIP

        一般地,类似python的脚本语言可用于算法快速实现、验证;但在产品化过程中,一般采用效率更高的C++语言,下面的工作就是将模型从python环境中移植到c++环境。

Step1: 将pytorch模型转为torch scrip类型的模型

通过TorchSript,我们可将pytorch模型从python转为c++。那么,什么是TorchScript呢?其实,它也是Pytorch模型的一种,这种模型能够被TorchScript的编译器识别读取、序列化。一般地,在处理模型过程中,我们都会先将模型转为torch script格式,例如:".pt"  -> "yolov5x.torchscript.pt"。转为torchscript格式有两种方法:一是函数torch.jit.trace;二是函数torch.jit.script。

torch.jit.trace原理:基于跟踪机制,需要输入一张图(0矩阵、张量亦可),模型会对输入的tensor进行处理,并记录所有张量的操作,torch::jit::trace能够捕获模型的结构、参数并保存。由于跟踪仅记录张量上的操作,因此它不会记录任何控制流操作,如if语句或循环。

torch.jit.script原理:需要开发者先定义好神经网络模型结构,即:提前写好 class MyModule(torch.nn.Module),这样TorchScript可以根据定义好的MyModule来解析网络结构。

1.1、基于Tracing的方法来转换为Torch Script

如下代码,给 torch.jit.trace 函数输入一个指定size的随机张量、ResNet18的网络模型,得到一个类型为 torch.jit.ScriptModule 的对象,即:traced_script_module

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

经过上述处理,traced_script_module变量已经包含网络的结构和参数,可以直接用于推理,如下代码:

1 In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
2 In[2]: output[0, :5]
3 Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

1.2、基于Annotating (Script)的方法来转换为Torch Script

如果你的模型中有类似于控制流操作(例如:if or for循环),基于上述tracing的方式不再适用,这种方式会排上用场,下面以vanilla模型为例子,注:下面网络结构中有个if判断。

# 定义一个vanilla模型
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

这里调用 torch.jit.script 来获取 torch.jit.ScriptModule 类型的对象,即:sm

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件

注:上述的tacing和script方法都将得到一个类型为torch.jit.ScriptModule的对象(这里简单记为:ScriptModule ),该对象就是常规的前向传播模块。不管是哪一种方法,此时,只需要将ScriptModule进行序列化保存就行。这里保存的是上述基于Tracing得到的ResNet推理模块traced_script_module。

traced_script_module.save("traced_resnet_model.pt") # 序列化,保存
# 保存后可用工具:https://netron.app/ 进行可视化

同理,如下是保存基于Annotating得到推理模块my_module 后续,在libtorch中加载上述保存的模型文件就行,不再依赖任何python包。

my_module.save("my_module_model.pt") # 为什么不是sm

Step3: 在libtorch中加载ScriptModule模型

 如何配置libtorh?,我这里仅贴下vs环境下的属性表:

1 include:
2 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include
4 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include\torch\csrc\api\include
7 lib:
8 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib
9 
11 链接器:
12 c10.lib
13 c10_cuda.lib
14 torch.lib
15 torch_cpu.lib
16 torch_cuda.lib
17 
18 环境变量:
19 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib

以下c++代码加载上述模型文件

#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}

总结

python模型的序列化、保存代码:

import torchvision
import torch

model = torchvision.models.resnet18()

example = torch.rand(1, 3, 224, 224)

traced_script_module = torch.jit.trace(model, example)

output = traced_script_module(torch.ones(1, 3, 224, 224))

#traced_script_module.save("traced_resnet_model.pt") # 和下面等价,格式名称不同,仅此而已,在libtorch中是一样的
traced_script_module.save("traced_resnet_model.torchscript.pt")
print()

libtorch的模型加载,推理代码:文章来源地址https://www.toymoban.com/news/detail-741080.html

#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    //std::string str = "traced_resnet_model.torchscript.pt"; // 和上面等价,模型格式而已
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}

到了这里,关于LibTorch实战四:模型序列化的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【序列化与反序列化】关于序列化与反序列化MessagePack的实践

    在进行序列化操作之前,我们还对系统进行压测,通过 jvisualvm 分析cpu,线程,垃圾回收情况等;运用火焰图 async-profiler 分析系统性能,找出程序中占用CPU资源时间最长的代码块。 代码放置GitHub:https://github.com/nateshao/leetcode/tree/main/source-code/src/main/java/com/nateshao/source/code/ser

    2024年02月11日
    浏览(47)
  • 【网络】序列化反序列化

    在前文《网络编程套接字》中,我们实现了服务器与客户端之间的字符串通信,这是非常简单的通信,在实际使用的过程中,网络需要传输的不仅仅是字符串,更多的是结构化的数据(类似于 class , struct 类似的数据)。 那么我们应该怎么发送这些结构化的数据呢? 如果我们

    2024年02月05日
    浏览(35)
  • 序列化,反序列化之实例

    介绍文章 __construct() 当一个对象创建时自动调用 __destruct() 当对象被销毁时自动调用 (php绝大多数情况下会自动调用销毁对象) __sleep() 使**用serialize()函数时触发 __wakeup 使用unserialse()**函数时会自动调用 __toString 当一个对象被当作一个字符串被调用 __call() 在对象上下文中调用不

    2024年02月14日
    浏览(36)
  • 协议,序列化,反序列化,Json

    协议究竟是什么呢?首先得知道主机之间的网络通信交互的是什么数据,像平时使用聊天APP聊天可以清楚,用户看到的不仅仅是聊天的文字,还能够看到用户的头像昵称等其他属性。也就可以证明网络通信不仅仅是交互字符串那么简单。事实上网络通信还可能会通过一个结构

    2024年02月13日
    浏览(29)
  • 【网络】协议定制+序列化/反序列化

    如果光看定义很难理解序列化的意义,那么我们可以从另一个角度来推导出什么是序列化, 那么究竟序列化的目的是什么? 其实序列化最终的目的是为了对象可以 跨平台存储,和进行网络传输 。而我们进行跨平台存储和网络传输的方式就是IO,而我们的IO支持的数据格式就是

    2024年02月08日
    浏览(35)
  • Qt 对象序列化/反序列化

    阅读本文大概需要 3 分钟 日常开发过程中,避免不了对象序列化和反序列化,如果你使用 Qt 进行开发,那么有一种方法实现起来非常简单和容易。 我们知道 Qt 的元对象系统非常强大,基于此属性我们可以实现对象的序列化和反序列化操作。 比如有一个学生类,包含以下几

    2024年02月13日
    浏览(31)
  • 什么是序列化和反序列化?

    JSON(JavaScript Object Notation)和XML(eXtensible Markup Language)是两种常用的数据交换格式,用于在不同系统之间传输和存储数据。 JSON是一种轻量级的数据交换格式,它使用易于理解的键值对的形式表示数据。JSON数据结构简单明了,易于读写和解析,是基于JavaScript的一种常用数据

    2024年02月09日
    浏览(44)
  • Spring Boot 序列化、反序列化

    在软件开发中,序列化和反序列化是一种将对象转换为字节流以便存储或传输的机制。序列化将对象转换为字节流,而反序列化则将字节流转换为对象。序列化和反序列化在许多应用场景中都起着重要的作用,比如在网络通信中传输对象、将对象存储到数据库中、实现分布式

    2024年02月15日
    浏览(30)
  • jackson自定义序列化反序列化

    自定义序列化 序列化主要作用在返回数据的时候 以BigDecimal统一返回3位小数为例 自定义序列化处理类 继承jackson的 JsonSerializer 类,重写 serialize 方法 使用的时候,可以直接使用Jackson的 @JsonSerialize 注解 自定义反序列化 接收前端传入数据 继承 JsonDeserializer 类,重写 deserializ

    2024年02月13日
    浏览(35)
  • 协议定制 + Json序列化反序列化

    1.1 结构化数据 协议是一种 “约定”,socket api的接口, 在读写数据时,都是按 “字符串” 的方式来发送接收的。如果我们要传输一些\\\"结构化的数据\\\" 怎么办呢? 结构化数据: 比如我们在QQ聊天时,并不是单纯地只发送了消息本身,是把自己的头像、昵称、发送时间、消息本身

    2024年02月09日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包