Nvidia Tensor Core-WMMA API编程入门

这篇具有很好参考价值的文章主要介绍了Nvidia Tensor Core-WMMA API编程入门。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1 WMMA (Warp-level Matrix Multiply Accumulate) API

对于计算能力在7.0及以上的CUDA设备,可以使用CUDA C++ API调用Tensor Core,支持形如D = AB + C的混合精度的矩阵乘运算。
template<typename Use, int m, int n, int k, typename T, typename Layout=void> class fragment;

void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout);
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);
void fill_fragment(fragment<...> &a, const T& v);
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...> &b, const fragment<...> &c, bool satf=false);
  • fragment:Tensor Core数据存储类,支持matrix_a、matrix_b和accumulator
  • load_matrix_sync:Tensor Core数据加载API,支持将矩阵数据从global memory或shared memory加载到fragment
  • store_matrix_sync:Tensor Core结果存储API,支持将计算结果从fragment存储到global memory或shared memory
  • fill_fragment:fragment填充API,支持常数值填充
  • mma_sync:Tensor Core矩阵乘计算API,支持D = AB + C或者C = AB + C

2 示例

以m16n16k16为例,实现HGEMM:C = AB,其中矩阵A(M * K,row major)、B(K * N,col major)和C(M * N,row major)的精度均为FP16。首先我们看如何使用CUDA Core写HGEMM naive算法。

2.1 CUDA Core

按照每个线程计算矩阵C中的一个元素来构建naive kernel,首先确定当前线程处理矩阵C的元素坐标,再遍历K并直接从global memory中加载所需A、B矩阵元素到寄存器参与计算,最后将计算结果从寄存器直接写回矩阵C。所有block计算完成之后即可得到矩阵C。这个例子不能说简单,只能说技术含量不高,不过我们只是为了对比。
#define DIV_CEIL(x, y) (((x) + (y) - 1) / (y))

__global__ void naiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
                                size_t N, size_t K) {
    size_t row = threadIdx.x + blockDim.x * blockIdx.x;
    size_t col = threadIdx.y + blockDim.y * blockIdx.y;
    if (row < M && col < N) {
        half tmp = 0.0;
        for (size_t i = 0; i < K; ++i) {
            tmp += A[row * K + i] * B[i + col * K];
        }
        C[row * N + col] = tmp;
    }
}

void hgemmNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
    dim3 block(16, 16);
    dim3 grid(DIV_CEIL(M, block.x), DIV_CEIL(N, block.y));

    naiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}

2.2 Tensor Core

我们再来看如何用WMMA API来构建naive kernel,参考cuda sample。与CUDA Core naive不同的是,WMMA需要按照每个warp处理一个矩阵C的WMMA_M * WMMA_N大小的tile的思路来构建,因为Tensor Core的计算层级是warp级别,计算的矩阵元素也是二维的。接下来,与CUDA Core naive的处理思路一致,首先确定当前warp处理矩阵C的tile坐标,声明计算tilie所需的fragment,再以WMMA_K为步长遍历K并直接从global memory中加载所需A、B矩阵tile到fragment参与计算,最后将计算结果从fragment直接写回矩阵C。所有block计算完成之后即可得到矩阵C。
值得注意的是,load_matrix_sync和store_matrix_sync都是按stride访问矩阵元素。
#include <mma.h>

#define WARP_SIZE 32

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

using namespace nvcuda;

__global__ void wmmaNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
                                size_t N, size_t K) {
    size_t warpM = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE;
    size_t warpN = (blockIdx.y * blockDim.y + threadIdx.y);

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);

    for (size_t i = 0; i < K; i += WMMA_K) {
        size_t aCol = i;
        size_t aRow = warpM * WMMA_M;
        size_t bCol = warpN * WMMA_N;
        size_t bRow = i;

        if (aRow < M && aCol < K && bRow < K && bCol < N) {
            wmma::load_matrix_sync(a_frag, A + aCol + aRow * K, K);
            wmma::load_matrix_sync(b_frag, B + bRow + bCol * K, K);

            wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        }
    }

    size_t cCol = warpN * WMMA_N;
    size_t cRow = warpM * WMMA_M;

    if (cRow < M && cCol < N) {
        wmma::store_matrix_sync(C + cCol + cRow * N, c_frag, N, wmma::mem_row_major);
    }
}

void hgemmWmmaNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
    dim3 block(128, 4);
    dim3 grid((M - 1) / (WMMA_M * block.x / WARP_SIZE) + 1, (N - 1) / (WMMA_N * block.y) + 1);

    wmmaNaiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}

2.3 区别

从上述两个naive kernel的代码来看调用CUDA Core和Tensor Core的区别如下:
  • 计算层级:CUDA Core是线程级别,Tensor Core是warp级别
  • 计算维度:CUDA Core是一维逐点计算,Tensor Core是二维逐tile计算
  • 计算依赖:WMMA调用Tensor Core需要借助数据存储类fragment,CUDA Core不需要借助其他

3 底层代码

我们再对上述WMMA naive kernel做进一步探索,看一下它在RTX A6000(sm_86,CUDA 11.3)上对应的PTX和SASS。

3.1 PTX

dump出对应的PTX代码如下,好像不那么简单了。
.visible .entry _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm(
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5
)
{
.reg .pred %p<8>;
.reg .b16 %rs<2>;
.reg .f32 %f<2>;
.reg .b32 %r<58>;
.reg .b64 %rd<28>;

ld.param.u64 %rd9, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0];
ld.param.u64 %rd10, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1];
ld.param.u64 %rd11, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2];
ld.param.u64 %rd14, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3];
ld.param.u64 %rd12, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4];
ld.param.u64 %rd13, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5];
mov.u32 %r19, %ntid.x;
mov.u32 %r20, %ctaid.x;
mov.u32 %r21, %tid.x;
mad.lo.s32 %r22, %r20, %r19, %r21;
mov.u32 %r23, %ntid.y;
mov.u32 %r24, %ctaid.y;
mov.u32 %r25, %tid.y;
mad.lo.s32 %r26, %r24, %r23, %r25;
mov.f32 %f1, 0f00000000;

	{ cvt.rn.f16.f32 %rs1, %f1;}

	mov.b32 %r50, {%rs1, %rs1};
mul.wide.u32 %rd1, %r26, 16;
shr.u32 %r27, %r22, 1;
and.b32 %r28, %r27, 2147483632;
cvt.u64.u32 %rd2, %r28;
setp.lt.u64 %p2, %rd2, %rd14;
setp.lt.u64 %p3, %rd1, %rd12;
and.pred %p1, %p2, %p3;
setp.eq.s64 %p4, %rd13, 0;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50;
@%p4 bra $L__BB0_5;

mul.lo.s64 %rd3, %rd2, %rd13;
cvt.u32.u64 %r2, %rd13;
mul.lo.s64 %rd4, %rd1, %rd13;
cvta.to.global.u64 %rd5, %rd10;
cvta.to.global.u64 %rd6, %rd9;
mov.u64 %rd27, 0;
not.pred %p5, %p1;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50;

$L__BB0_2:
@%p5 bra $L__BB0_4;

add.s64 %rd16, %rd27, %rd3;
shl.b64 %rd17, %rd16, 1;
add.s64 %rd18, %rd6, %rd17;
wmma.load.a.sync.aligned.row.m16n16k16.global.f16 {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, [%rd18], %r2;
add.s64 %rd19, %rd27, %rd4;
shl.b64 %rd20, %rd19, 1;
add.s64 %rd21, %rd5, %rd20;
wmma.load.b.sync.aligned.col.m16n16k16.global.f16 {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, [%rd21], %r2;
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 {%r53, %r52, %r51, %r50}, {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, {%r53, %r52, %r51, %r50};

$L__BB0_4:
add.s64 %rd27, %rd27, 16;
setp.lt.u64 %p6, %rd27, %rd13;
@%p6 bra $L__BB0_2;

$L__BB0_5:
not.pred %p7, %p1;
@%p7 bra $L__BB0_7;

mul.lo.s64 %rd22, %rd2, %rd12;
add.s64 %rd23, %rd22, %rd1;
cvta.to.global.u64 %rd24, %rd11;
shl.b64 %rd25, %rd23, 1;
add.s64 %rd26, %rd24, %rd25;
cvt.u32.u64 %r45, %rd12;
wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%rd26], {%r53, %r52, %r51, %r50}, %r45;

$L__BB0_7:
ret;

}

不过我们主要关注WMMA相关的PTX指令,如下所示。可以看到这里正是Nvidia提供的WMMA PTX指令来调用Tensor Core,所以无论是使用WMMA API编程,还是使用WMMA PTX指令编程,底层差别不会太大。

wmma.load.a.sync.aligned.row.m16n16k16.global.f16
wmma.load.b.sync.aligned.col.m16n16k16.global.f16
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16
wmma.store.d.sync.aligned.row.m16n16k16.global.f16

3.2 SASS

进一步dump出对应的SASS代码,似乎也不简单。
      IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] 
      S2R R0, SR_CTAID.X 
      ISETP.NE.U32.AND P2, PT, RZ, c[0x0][0x188], PT 
      ULDC.64 UR4, c[0x0][0x118] 
      CS2R R8, SRZ 
      S2R R10, SR_CTAID.Y 
      ISETP.NE.AND.EX P2, PT, RZ, c[0x0][0x18c], PT, P2 
      S2R R5, SR_TID.Y 
      S2R R3, SR_TID.X 
      IMAD R10, R10, c[0x0][0x4], R5 
      IMAD R0, R0, c[0x0][0x0], R3 
      IMAD.WIDE.U32 R10, R10, 0x10, RZ 
      CS2R R2, SRZ 
      SHF.R.U32.HI R0, RZ, 0x1, R0 
      ISETP.GE.U32.AND P0, PT, R10, c[0x0][0x180], PT 
      LOP3.LUT R13, R0, 0x7ffffff0, RZ, 0xc0, !PT 
      ISETP.GE.U32.AND.EX P0, PT, R11, c[0x0][0x184], PT, P0 
      ISETP.LT.U32.AND P1, PT, R13, c[0x0][0x178], PT 
      ISETP.LT.U32.AND.EX P0, PT, RZ, c[0x0][0x17c], !P0, P1 
@!P2  BRA 0x7f1eaefc0160 
      BSSY B0, 0x7f1eaefc0160 
      IMAD.MOV.U32 R0, RZ, RZ, RZ 
      CS2R R8, SRZ 
      IMAD.MOV.U32 R15, RZ, RZ, RZ 
      IMAD.MOV.U32 R2, RZ, RZ, RZ 
      BSSY B1, 0x7f1eaefc0100 
@!P0  BRA 0x7f1eaefc00f0 
      S2R R16, SR_LANEID 
      IMAD R17, R11, c[0x0][0x188], RZ 
      IMAD.MOV.U32 R14, RZ, RZ, R0 
      IMAD.MOV.U32 R23, RZ, RZ, c[0x0][0x188] 
      IMAD.WIDE.U32 R6, R10, c[0x0][0x188], R14 
      SHF.R.U32.HI R12, RZ, 0x1, R23 
      IMAD R17, R10, c[0x0][0x18c], R17 
      LEA R21, P2, R6, c[0x0][0x168], 0x1 
      IMAD.WIDE.U32 R4, R13, c[0x0][0x188], R14 
      IMAD.IADD R7, R7, 0x1, R17 
      IMAD.MOV.U32 R17, RZ, RZ, RZ 
      IMAD R5, R13, c[0x0][0x18c], R5 
      LEA.HI.X R7, R6, c[0x0][0x16c], R7, 0x1, P2 
      SHF.R.U32.HI R19, RZ, 0x2, R16 
      LOP3.LUT R16, R16, 0x3, RZ, 0xc0, !PT 
      IMAD.WIDE.U32 R16, R19, R12, R16 
      LEA R19, P1, R4, c[0x0][0x160], 0x1 
      LEA.HI.X R5, R4, c[0x0][0x164], R5, 0x1, P1 
      LEA R18, P1, R16, R19, 0x2 
      LEA R20, P2, R16, R21, 0x2 
      LEA.HI.X R19, R16, R5, R17, 0x2, P1 
      LEA.HI.X R21, R16, R7, R17, 0x2, P2 
      IMAD.WIDE.U32 R16, R23, 0x10, R18 
      LDG.E R4, [R18.64] 
      IMAD.WIDE.U32 R22, R23, 0x10, R20 
      LDG.E R24, [R20.64] 
      LDG.E R25, [R20.64+0x10] 
      LDG.E R6, [R18.64+0x10] 
      LDG.E R5, [R16.64] 
      LDG.E R7, [R16.64+0x10] 
      LDG.E R26, [R22.64] 
      LDG.E R27, [R22.64+0x10] 
      WARPSYNC 0xffffffff 
      HMMA.16816.F16 R8, R4, R24, R8 
      HMMA.16816.F16 R2, R4, R26, R2 
      NOP 
      BSYNC B1 
      IADD3 R0, P1, R0, 0x10, RZ 
      IMAD.X R15, RZ, RZ, R15, P1 
      ISETP.GE.U32.AND P1, PT, R0, c[0x0][0x188], PT 
      ISETP.GE.U32.AND.EX P1, PT, R15, c[0x0][0x18c], PT, P1 
@!P1  BRA 0x7f1eaefbfe90 
      BSYNC B0 
@!P0  EXIT 
      S2R R4, SR_LANEID 
      IMAD.MOV.U32 R15, RZ, RZ, c[0x0][0x180] 
      WARPSYNC 0xffffffff 
      IMAD.WIDE.U32 R10, R13, c[0x0][0x180], R10 
      SHF.R.U32.HI R15, RZ, 0x1, R15 
      IMAD.MOV.U32 R5, RZ, RZ, RZ 
      LEA R7, P0, R10, c[0x0][0x170], 0x1 
      IMAD R11, R13, c[0x0][0x184], R11 
      LEA.HI.X R11, R10, c[0x0][0x174], R11, 0x1, P0 
      SHF.R.U32.HI R0, RZ, 0x2, R4 
      LOP3.LUT R4, R4, 0x3, RZ, 0xc0, !PT 
      IMAD.WIDE.U32 R4, R0, R15, R4 
      LEA R6, P0, R4, R7, 0x2 
      LEA.HI.X R7, R4, R11, R5, 0x2, P0 
      IMAD.WIDE.U32 R4, R15, 0x20, R6 
      STG.E [R6.64], R8 
      STG.E [R4.64], R9 
      STG.E [R6.64+0x10], R2 
      STG.E [R4.64+0x10], R3 
      EXIT 
      BRA 0x7f1eaefc02b0
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP

我们依然主要关注WMMA相关的SASS指令,如下所示。可以发现WMMA161616在底层是通过两个HMMA16816指令实现,同样地,SASS指令也是Nvidia提供的另一种调用Tensor Core的编程方法。文章来源地址https://www.toymoban.com/news/detail-411627.html

HMMA.16816.F16
Nvidia Tensor Core初探中提到Nvidia提供了四种调用Tensor Core的编程方法,这里提到了三种,还有一种是MMA PTX指令,其中MMA16816 PTX指令底层实现即是HMMA16816指令,后续会在MMA PTX相关文章中提及。

4 其他

4.1 HGEMM优化

学习WMMA API的目标在于调用Tensor Core优化HGEMM,相比于cublas,WMMA的性能究竟如何?
 
 

到了这里,关于Nvidia Tensor Core-WMMA API编程入门的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 入门Pytorch:对Tensor的操作

    目录 前言 一、创建 list创建 numpy创建 填充创建 初始化 规律变化 指定类型创建 指定数据类型 转换数据类型 二、索引 直接索引 切片 用...表示多个被省略 三、维度变换 view,reshape维度变换 unsqueeze插入维度 squeeze删除维度 repeat复制维度 维度交换 四、广播机制 五、拼接和拆分

    2024年02月16日
    浏览(43)
  • 《Pytorch新手入门》第一节-认识Tensor

    参考《深度学习框架PyTorch:入门与实践_陈云(著)》 Tensor 是 PyTorch 中重要的数据结构,可认为是一个高维数组。它可以是一个数(标量)一维数组(向量)二维数组(阵)或更高的数组。Tensor 和 numpy的ndarrays类似,但Tensor 可以使用GPU加速。 torch.size是tuple对象的子类,因此它支持 tup

    2024年02月06日
    浏览(40)
  • Pytorch入门:Tensor加减乘除矩阵运算

    若张量维数大于2,则对最后两维进行matmul。进行此运算的要求是张量a与b除最后两维外的其他维必须一致:

    2024年02月12日
    浏览(46)
  • pytorch 入门1-tensor 广播 view reshape

    tensor 的四则运算 broadcast 常见的构造Tensor的方法: out:

    2024年02月12日
    浏览(43)
  • 从 X 入门Pytorch——Tensor的索引,切片,拼接,拆分,Reduction操作

    本文参加新星计划人工智能(Pytorch)赛道: https://bbs.csdn.net/topics/613989052 承接上文:自己深度学习环境搭建和免费环境使用+Tensor构造+Tensor基本操作: 从 X 入门深度学习(Pytorch版本) 汇总: Name Out a[i, j, k, …] = a[i][j][k][…] 获取张量a的具体数据 a[start : end : step, start1 : end1 : step1

    2024年02月03日
    浏览(42)
  • 3、flink重要概念(api分层、角色、执行流程、执行图和编程模型)及dataset、datastream详细示例入门和提交任务至on yarn运行

    一、Flink 专栏 Flink 专栏系统介绍某一知识点,并辅以具体的示例进行说明。 1、Flink 部署系列 本部分介绍Flink的部署、配置相关基础内容。 2、Flink基础系列 本部分介绍Flink 的基础部分,比如术语、架构、编程模型、编程指南、基本的datastream api用法、四大基石等内容。 3、

    2024年02月12日
    浏览(44)
  • 【NVIDIA CUDA】2023 CUDA夏令营编程模型(二)

    博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客内容主要围绕:        5G/6G协议

    2024年02月10日
    浏览(38)
  • Flink(三)flink重要概念(api分层、角色、执行流程、执行图和编程模型)及dataset、datastream详细示例入门和提交任务至on yarn运行

    一、Flink 专栏 Flink 专栏系统介绍某一知识点,并辅以具体的示例进行说明。 1、Flink 部署系列 本部分介绍Flink的部署、配置相关基础内容。 2、Flink基础系列 本部分介绍Flink 的基础部分,比如术语、架构、编程模型、编程指南、基本的datastream api用法、四大基石等内容。 3、

    2024年02月16日
    浏览(46)
  • .net Core API 添加 NLog

    nlog.config program.cs  NuGet packages:NLog、NLog.Web.AspNetCore 

    2024年02月12日
    浏览(37)
  • ASP.NET CORE API 使用Orleans

    快速使用Monimal API 快速集成Orleans 微软官网地址如下: https://learn.microsoft.com/zh-cn/dotnet/orleans/quickstarts/build-your-first-orleans-app?source=recommendationstabs=visual-studio 当然它的存储grain存储采用的是内存级别存储,我缓存了mssql 存储。如果是内存存储使用如下代码就Ok 我采用的是数据库存

    2024年02月06日
    浏览(57)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包