编译 MXNet 模型

这篇具有很好参考价值的文章主要介绍了编译 MXNet 模型。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

本篇文章译自英文文档 Compile MXNet Models。

作者是 Joshua Z. Zhang,Kazutaka Morita。

更多 TVM 中文文档可访问 →TVM 中文站。

本文将介绍如何用 Relay 部署 MXNet 模型。

首先安装 mxnet 模块,可通过 pip 快速安装:

pip install mxnet --user

或参考官方安装指南:https://mxnet.apache.org/versions/master/install/index.html

# 一些标准的导包
import mxnet as mx
import tvm
import tvm.relay as relay
import numpy as np

从 Gluon Model Zoo 下载 Resnet18 模型

本节会下载预训练的 imagenet 模型,并对图像进行分类。

from tvm.contrib.download import download_testdata
from mxnet.gluon.model_zoo.vision import get_model
from PIL import Image
from matplotlib import pyplot as plt

block = get_model("resnet18_v1", pretrained=True)
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_name = "cat.png"
synset_url = "".join(
    [
        "https://gist.githubusercontent.com/zhreshold/",
        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
        "imagenet1000_clsid_to_human.txt",
    ]
)
synset_name = "imagenet1000_clsid_to_human.txt"
img_path = download_testdata(img_url, "cat.png", module="data")
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synset = eval(f.read())
image = Image.open(img_path).resize((224, 224))
plt.imshow(image)
plt.show()

def transform_image(image):
    image = np.array(image) - np.array([123.0, 117.0, 104.0])
    image /= np.array([58.395, 57.12, 57.375])
    image = image.transpose((2, 0, 1))
    image = image[np.newaxis, :]
    return image

x = transform_image(image)
print("x", x.shape)

编译 MXNet 模型
输出结果:

Downloading /workspace/.mxnet/models/resnet18_v1-a0666292.zip08d19deb-ddbf-4120-9643-fcfab19e7541 from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet18_v1-a0666292.zip...
x (1, 3, 224, 224)

编译计算图

只需几行代码,即可将 Gluon 模型移植到可移植计算图上。mxnet.gluon 支持 MXNet 静态图(符号)和 HybridBlock。

shape_dict = {"data": x.shape}
mod, params = relay.frontend.from_mxnet(block, shape_dict)
## 添加 softmax 算子来提高概率
func = mod["main"]
func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)

接下来编译计算图:

target = "cuda"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(func, target, params=params)

输出结果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "

在 TVM 上执行可移植计算图

接下来用 TVM 重现相同的前向计算:

from tvm.contrib import graph_executor

dev = tvm.cuda(0)
dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# 设置输入
m.set_input("data", tvm.nd.array(x.astype(dtype)))
# 执行
m.run()
# 得到输出
tvm_output = m.get_output(0)
top1 = np.argmax(tvm_output.numpy()[0])
print("TVM prediction top-1:", top1, synset[top1])

输出结果:

TVM prediction top-1: 282 tiger cat

使用带有预训练权重的 MXNet 符号

MXNet 常用 arg_params 和 aux_params 分别存储网络参数,下面将展示如何在现有 API 中使用这些权重:

def block2symbol(block):
    data = mx.sym.Variable("data")
    sym = block(data)
    args = {}
    auxs = {}
    for k, v in block.collect_params().items():
        args[k] = mx.nd.array(v.data().asnumpy())
    return sym, args, auxs

mx_sym, args, auxs = block2symbol(block)
# 通常将其保存/加载为检查点
mx.model.save_checkpoint("resnet18_v1", 0, mx_sym, args, auxs)
# 磁盘上有 "resnet18_v1-0000.params" 和 "resnet18_v1-symbol.json"

对于一般性 MXNet 模型:

mx_sym, args, auxs = mx.model.load_checkpoint("resnet18_v1", 0)
# 用相同的 API 来获取 Relay 计算图
mod, relay_params = relay.frontend.from_mxnet(mx_sym, shape_dict, arg_params=args, aux_params=auxs)
# 重复相同的步骤,用 TVM 运行这个模型

下载 Python 源代码:from_mxnet.py

下载 Jupyter Notebook:from_mxnet.ipynb文章来源地址https://www.toymoban.com/news/detail-445770.html

到了这里,关于编译 MXNet 模型的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • ChatGPT专业应用:撰写英文SEO文章

    正文共  561  字,阅读大约需要  2  分钟 品牌营销/活动运营必备技巧,您将在2分钟后获得以下超能力: 快速生成英文SEO文章 Beezy评级 :B级 *经过简单的寻找, 大部分人能立刻掌握。主要节省时间。 推荐人  | Alice   编辑者  |  Linda ●此图片由Lexica 自动生成,输入:

    2024年02月13日
    浏览(43)
  • 【C++】String类基本接口介绍(多看英文文档)

    string目录 目录  如果你很赶时间,那么就直接看我本标题下的内容即可!! 一、STL简介 1.1什么是STL 1.2STL版本 1.3STL六大组件 1.4STL重要性 1.5如何学习STL 二、什么是string??(本质上是一个类) 三、string的类模板(什么?string居然利用了模板??) 三、string的三种构造(拷贝

    2024年02月07日
    浏览(41)
  • 全网最详细中英文ChatGPT-API文档(一)开始使用ChatGPT——导言

    The OpenAI API can be applied to virtually any task that involves understanding or generating natural language or code. We offer a spectrum of models with different levels of power suitable for different tasks, as well as the ability to fine-tune your own custom models. These models can be used for everything from content generation to semantic search and cl

    2023年04月25日
    浏览(50)
  • [Rust笔记] 为什么Rust英文文档普遍将【枚举值】记作variant而不是enum value?

    在阅读各类 Rust 英文技术资料时,你是否也曾经困惑过:为何每逢【枚举值】的概念出现时,作者都会以 variant 一词指代之?就字面含义而言, enum value 岂不是更贴切与易理解。简单地讲,这馁馁地是 Rust 技术优越性·宣传软文的广告梗,而且是很高端的内行梗。 Rustacean 们看

    2023年04月08日
    浏览(47)
  • 初入公司用不好git ?-- 本篇针对GitLab

    本篇并不涉及git的所有知识,内容包括工作中每天用到的以及需要知道的 一、从远程仓库拉取指定分支到本地仓库,并创建个人分支  二、(补充)基于以上补充几点基础知识点以便你更好理解并实践 1. 主分支:通常是master分支 2. 开发分支:基于主分支派生,你通常在这个

    2024年02月05日
    浏览(48)
  • NeMo中文/英文ASR模型微调训练实践

    1.安装nemo pip install -U nemo_toolkit[all] ASR-metrics 2.下载ASR预训练模型到本地(建议使用huggleface,比nvidia官网快很多) 3.从本地创建ASR模型 asr_model = nemo_asr.models.EncDecCTCModel.restore_from(\\\"stt_zh_quartznet15x5.nemo\\\") 3.定义train_mainfest,包含语音文件路径、时长和语音文本的json文件 4.读取模型的

    2024年02月13日
    浏览(40)
  • 【C语言】P166 10.有一篇文章,共有3行文字,每行有80个字符。要求分别统计出其中英文大写字母、小写字母、数字、空格以及其他字符的个数

    P166 10.有一篇文章,共有3行文字,每行有80个字符。要求分别 统计出其中英文大写字母、小写字母、数字、空格以及其他字符的个数   运行结果:  

    2024年02月04日
    浏览(65)
  • 利用PyTorch训练模型识别数字+英文图片验证码

    摘要:使用深度学习框架PyTorch来训练模型去识别4-6位数字+字母混合图片验证码(我们可以使用第三方库captcha生成这种图片验证码或者自己收集目标网站的图片验证码进行针对训练)。 一、制作训练数据集 我们可以把需要生成图片的一些参数放在setting.py文件中,方便以后更

    2024年04月15日
    浏览(45)
  • DedeCMS给文章添加“当前文档地址”和“转载说明”的方法

    在DedeCMS给文章添加“当前文档地址”和“转载说明”,文档内容结尾加一个转载说明,包含当前文档页面网址,如果文章被许多站长采集或转载,无疑可以增加很多外链! 下面来看看织梦CMS搭建的网站,如何添加这一功能? 这里,我们以DedeCMS的文章模型为例,其他模型类似

    2024年02月03日
    浏览(40)
  • 最强英文开源模型LLaMA架构探秘,从原理到源码

    导读: LLaMA 65B 是由Meta AI(原Facebook AI)发布并宣布开源的真正意义上的千亿级别大语言模型,发布之初(2023年2月24日)曾引起不小的轰动。LLaMA的横空出世,更像是模型大战中一个搅局者。虽然它的效果(performance)和GPT-4仍存在差距,但GPT-4毕竟是闭源的商业模型,LLaMA系列

    2024年02月04日
    浏览(41)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包