TVM 源码阅读PASS — VectorizeLoop

这篇具有很好参考价值的文章主要介绍了TVM 源码阅读PASS — VectorizeLoop。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

本文地址:https://www.cnblogs.com/wanger-sjtu/p/17501119.html

VectorizeLoop这个PASS就是对标记为ForKind::kVectorizedFor循环做向量化处理,并对For循环中的语句涉及到的变量,替换为Ramp,以便于在Codegen的过程中生成相关的向量化运算的指令。

VectorizeLoop这个PASS的入口函数如下,只有在打开enable_vectorize=true的情况下载才会被启用,否则VectorizeSkipper会把ForKind::kVectorizedFor循环替换为普通循环。

Pass VectorizeLoop(bool enable_vectorize) {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    if (enable_vectorize) {
      n->body = LoopVectorizer()(std::move(n->body));
    } else {
      n->body = VectorizeSkipper()(std::move(n->body));
    }
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
}

下面就以UT中的几个例子,介绍一下源码实现。

vectorize_loop

dtype = "int64"
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")

with ib.for_range(0, n) as i:
 with ib.for_range(0, 4, kind="vectorize") as j:
     A[i*4+j] += tvm.tir.const(1, A.dtype)
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body

上面的这个代码完成的是,向量加法,长度为4n的向量A,对每个元素+1。

# before
for (i, 0, n) {
  vectorized (j, 0, 4) {
    A[((i*4) + j)] = (A[((i*4) + j)] + 1f)
  }
}
# after
for (i, 0, n) {
  A[ramp((i*4), 1, 4)] = (A[ramp((i*4), 1, 4)] + x4(1f))
}

可以看到在经过VectorizeLoop的PASS以后,内层的循环消掉了,替换成为了一个Ramp的向量指令,这个在CPU中会被替换为SIMD指令(neon,AVX等)

PASS流程

在向量化的处理的PASS中是在LoopVectorizer中处理的,处理For循环部分。

class LoopVectorizer : public StmtMutator {
 public:
  Stmt VisitStmt_(const ForNode* op) final {
    if (op->kind == ForKind::kVectorized) {
      ICHECK(is_zero(op->min));
      auto* extent_as_int = op->extent.as<IntImmNode>();
      if (!extent_as_int || extent_as_int->value < 1) {
        LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
      }
      return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
    } else {
      return StmtMutator::VisitStmt_(op);
    }
  }
};

当遇到需要向量化的节点时,首先记录循环变量和范围,这个在后续替换相应的Load和Store操作为Ramp时用到。然后就到了Vectorizer部分,遍历For循环体,修改相应的stmt。

Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
    ramp_ = Ramp(0, 1, var_lanes);
}

在Vectorizer中对不同的PrimExprStmt做了重载。这里不逐一介绍,就以上面的向量加计算,介绍一下用到的函数以及流程。

首先看一下这里的上面sch的For的循环内的计算逻辑:

 A[((i*4) + j)] = (A[((i*4) + j)] + 1f)

因为TVM中,Stmt的表达可以视为一个DSL的语言,访问的时候也是按照深度优先的策略遍历的AST,这里把上面的计算过程简单表示为一个AST的语法树,然后再分析一下流程中调用的各个函数是如何处理的。

TVM 源码阅读PASS — VectorizeLoop

从上面的AST的示意图可以看出来,对于上面的sch,依次访问了BufferStoreNodeAdd MulBufferLoadNode 等。这里就以这几个Node的处理介绍一下向量化的过程。

所谓向量化的过程就是把这个标记为kVectorized的标量循环操作映射到向量化的操作,对于上面的例子来说就是把所有关于j的访问映射为RampNode,以便于后续处理可以正确生成相应的指令。

BufferStoreNode

BufferStoreNode中有三部分:

  • buffer——写入的buffer
  • value——待写入的值或者表达式
  • indices——写入buffer的坐标
    这里的目的就是修改valueindices中的内容。
    对于indices,是在这里完成的。最终通过MapHelper依次访问了indices的表达式。
auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
Array<PrimExpr> indices = op->indices.Map(fmutate);

对于value 则是直接遍历。

PrimExpr value = this->VisitExpr(op->value);
AddNode

对于AddNodeSubNode 都会走到AddSubVec这个模板函数。
这个函数里面首先会遍历左右表达式,

PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
 return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
 const RampNode* b_ramp = b.as<RampNode>();
 const RampNode* a_ramp = a.as<RampNode>();
 if (a.dtype().lanes() == 1 && b_ramp) {
   return Ramp(fcompute(a, b_ramp->base),
		 fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
 }
 if (b.dtype().lanes() == 1 && a_ramp) {
   return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
 }
}
return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes));

如果遍历之后没有变化,就直接返回了。而对于这里的我们需要计算的是

((i*4) + j)

j 是需要向量化的坐标。i*4 是没有变化的。遍历以后a没变化,b变成了T.Ramp(0, 1, 4) 这时候lanes=4,会走到第一个if分支,返回的是新构造的RampNode

 T.Ramp(i * 4, 1, 4)

其他的分支也类似。比如:

A[i * 4 + j] + T.float32(1)
// --- after ---
A[i * 4:i * 4 + 4]   T.float32(1)

这里会把a、b broadcast为一个向量再做计算。

VarNode

对于这里的VarNode判断就比较简单了,如果匹配到的是需要向量化的变量,就返回构造函数中构造的RampNode,否则就返回。其他的操作,暂时略过。

Var var = GetRef<Var>(op);
if (var.same_as(var_)) {
 return ramp_;
}
// ...
else {
 return std::move(var);
}
MulNode
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
 const RampNode* b_ramp = b.as<RampNode>();
 const RampNode* a_ramp = a.as<RampNode>();
 if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
   return Ramp(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
 }
 if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
   return Ramp(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
 }
}
return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
return BinaryVec<Mul>(op);

这里的处理逻辑与Add基本一致。只是在计算RampNode的时候有点区别。文章来源地址https://www.toymoban.com/news/detail-498122.html

到了这里,关于TVM 源码阅读PASS — VectorizeLoop的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Linux系列文章 —— vim的基本操作(误入vim退出请先按「ESC」再按:q不保存退出,相关操作请阅读本文)

    vim-操作篇 进程概念篇 进程地址空间篇 Linux,是一种免费使用和自由传播的类UNIX操作系统,是一个基于POSIX的多用户、多任务、支持多线程和多CPU的操作系统。它能运行主要的Unix工具软件、应用程序和网络协议。Linux继承了Unix以网络为核心的设计思想,是一个性能稳定的多用

    2024年02月03日
    浏览(46)
  • 轻量封装WebGPU渲染系统示例<7>-材质多pass(源码)

    当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/MultiMaterialPass.ts 此示例渲染系统实现的特性: 1. 用户态与系统态隔离。          细节请见:引擎系统设计思路 - 用户态与系统态隔离-CSDN博客 2. 高频调用与低频调用隔离。 3. 面向用户的易

    2024年02月08日
    浏览(93)
  • 含源码|基于MATLAB的去雾系统(5种去雾算法+1种本文的改进算法)

    去雾系统V2包括作者新加入的 多尺度Retinex去雾算法以及改进去雾算法 ,以及 4种 评价去雾效果的 客观指标 。 引言 去雾系统新增功能 结果分析 源码获取 展望 参考文献 在作者前面写过的文章中,已经介绍过图像去雾算法的应用价值及研究现状,并且也介绍了4种去雾算法的

    2024年01月23日
    浏览(79)
  • IP地址怎么实现HTTPS访问?

    为IP地址申请HTTPS证书跟为域名申请证书的步骤大同小异。需要注意的是,IP地址证书申请安装后, 用户只有通过IP地址访问才会显示访问安全 ,如果通过该IP地址所绑定的域名访问的话是没办法显示安全绿锁的。因为之前有用户问过小编——-我的ip地址绑定多个域名,我直接

    2024年01月17日
    浏览(52)
  • 前端问题:如何使网页中的http地址自动升级为https地址

    我一个搞后端开发的天天捣鼓前端的事,会不会被各位同僚念叨,哈哈。项目上的需求,需要把现在的https地址转换成http的地址,然而修改了nginx配置,摘除了证书,访问的所有静态文件依然是https,捣鼓了好半天,终于搞明白了原理,反其道而行,就有了这篇文章。 当我们

    2024年02月14日
    浏览(41)
  • notepad++官网地址 https://notepad-plus-plus.org/;notepad++ 官网地址 https://notepad-plus-plus.org/

    notepad++ 官网地址 https://notepad-plus-plus.org/ 今天想进官网下载notepad++ ,却发现百度搜索官网都是出来很多乱七八糟的,就自己记录一下 notepad++官网:https://notepad-plus-plus.org/ notepad++项目主页:https://github.com/notepad-plus-plus/notepad-plus-plus/

    2024年02月11日
    浏览(42)
  • 如何用JS校验HTTP和HTTPS地址

    当我们需要验证用户输入的网址时,经常需要校验是否为合法的HTTP或HTTPS地址。下面是一些JS代码,可以用来验证HTTP和HTTPS地址。 使用该函数来检查HTTP地址是否有效: 使用该函数来检查HTTPS地址是否有效: 以上JS代码可以很方便地验证HTTP和HTTPS地址的有效性。希望对你有所帮

    2024年02月07日
    浏览(39)
  • TVM编译器推理加速模型

    TVM是一个开源的端到端优化机器学习编译器,目的是加速模型在任意硬件上的计算。 一般情况下如果实在intel的cpu上面部署可能用OpenVino,N卡上面肯定TensorRT,arm架构机器可能会用Ncnn等,意味着要针对每个框架做部署,这里面涉及到的转换非常复杂,部署过的就知道有多少坑

    2024年01月19日
    浏览(55)
  • 《TCP/IP网络编程》阅读笔记--域名及网络地址

    目录 1--域名系统 2--域名与 IP 地址的转换 2-1--利用域名来获取 IP 地址 2-2--利用 IP 地址获取域名 3--代码实例 3-1--gethostbyname() 3-2--gethostbyaddr()         域名系统(Domain Name System, DNS )是对 IP 地址和域名进行相互转换 的系统,其核心是 DNS 服务器;         一般来说, IP

    2024年02月09日
    浏览(65)
  • 简单区分网页地址中http://和https://的区别

    HTTP(HyperText Transfer Protocol)HTTPS(HyperText Transfer Protocol Secure)是两种用于传输数据的协议,它们有以下主要区别: 首先从大的方面来讲: http:超文本传输协议,是一种不安全的协议,对数据不提供任何形式的加密。数据在传输过程中以明文形式发送,容易被中间人窃听和篡

    2024年02月11日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包