1 WMMA (Warp-level Matrix Multiply Accumulate) API
// WMMA API surface (declarations only; `fragment<...>` is a placeholder for a
// concrete fragment instantiation).

// Tensor Core data storage class. `Use` selects the role of the fragment:
// matrix_a, matrix_b or accumulator. m, n, k give the tile shape, T the element
// type, and Layout the memory layout tag (the samples in this file pass
// row_major / col_major for matrix_a / matrix_b and omit it for accumulator).
template<typename Use, int m, int n, int k, typename T, typename Layout=void> class fragment;

// Loads matrix data from global or shared memory at `mptr` into fragment `a`.
// `ldm` is the leading dimension of the matrix in memory, in elements (the
// samples in this file pass K for A and B).
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);

// Same as above, with the memory layout supplied at run time via `layout`
// instead of through the fragment's Layout template argument.
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout);

// Stores the computation result held in fragment `a` to global or shared
// memory at `mptr`, using leading dimension `ldm` and memory layout `layout`.
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);

// Fills every element of fragment `a` with the constant value `v`.
void fill_fragment(fragment<...> &a, const T& v);

// Tensor Core matrix multiply-accumulate: d = a * b + c. Passing the same
// fragment as `d` and `c` gives the in-place form c = a * b + c.
// NOTE(review): `satf` presumably enables saturation of the result — confirm
// against the CUDA Programming Guide before relying on it.
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...> &b, const fragment<...> &c, bool satf=false);
-
fragment:Tensor Core数据存储类,支持matrix_a、matrix_b和accumulator
-
load_matrix_sync:Tensor Core数据加载API,支持将矩阵数据从global memory或shared memory加载到fragment
-
store_matrix_sync:Tensor Core结果存储API,支持将计算结果从fragment存储到global memory或shared memory
-
fill_fragment:fragment填充API,支持常数值填充
-
mma_sync:Tensor Core矩阵乘计算API,支持D = AB + C或者C = AB + C
2 示例
2.1 CUDA Core
#define DIV_CEIL(x, y) (((x) + (y) - 1) / (y)) __global__ void naiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M, size_t N, size_t K) { size_t row = threadIdx.x + blockDim.x * blockIdx.x; size_t col = threadIdx.y + blockDim.y * blockIdx.y; if (row < M && col < N) { half tmp = 0.0; for (size_t i = 0; i < K; ++i) { tmp += A[row * K + i] * B[i + col * K]; } C[row * N + col] = tmp; } } void hgemmNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) { dim3 block(16, 16); dim3 grid(DIV_CEIL(M, block.x), DIV_CEIL(N, block.y)); naiveKernel<<<grid, block>>>(A, B, C, M, N, K); }
2.2 Tensor Core
#include <mma.h>

#define WARP_SIZE 32
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

using namespace nvcuda;

/**
 * Naive WMMA HGEMM: each warp computes one WMMA_M x WMMA_N tile of C = A * B
 * using 16x16x16 half-precision fragments with a half-precision accumulator.
 *
 * Indexing used by this kernel:
 *   - A tiles are loaded from A + aRow * K + k with leading dimension K (row-major, M x K)
 *   - B tiles are loaded from B + bCol * K + k with leading dimension K (column-major, K x N)
 *   - the C tile is stored to C + cRow * N + cCol with leading dimension N (row-major, M x N)
 *
 * Tile mapping: warpM = (global x thread index) / WARP_SIZE selects the tile
 * row and warpN = global y thread index selects the tile column. With the
 * launch shape used by hgemmWmmaNaive (blockDim.x a multiple of WARP_SIZE) all
 * lanes of a warp share the same warpM / warpN, so every guard below is taken
 * uniformly by the whole warp around the *_sync calls.
 *
 * Bounds handling: every load and store touches a full 16x16 tile, so a tile
 * is only processed when it lies completely inside the matrix. Checking only
 * the tile origin would read and write past the matrix edge whenever M, N or K
 * is not a multiple of 16. Consequently, rows/columns of C beyond the last
 * complete tile are left untouched and a trailing partial K slice is not
 * accumulated; callers that need those must pad their matrices to multiples
 * of 16. For dimensions that are multiples of 16 the result is unchanged.
 *
 * NOTE(review): the WMMA load/store calls additionally impose alignment and
 * stride requirements on the pointer and leading dimension — see the CUDA
 * Programming Guide.
 *
 * @param A  left operand, M x K
 * @param B  right operand, K x N
 * @param C  output, M x N
 * @param M  number of rows of A and C
 * @param N  number of columns of B and C
 * @param K  shared (reduction) dimension
 */
__global__ void wmmaNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
                                size_t N, size_t K) {
    const size_t warpM = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE;
    const size_t warpN = blockIdx.y * blockDim.y + threadIdx.y;

    const size_t cRow = warpM * WMMA_M;
    const size_t cCol = warpN * WMMA_N;

    // Skip warps whose output tile is not fully inside C; storing it would
    // write out of bounds.
    if (cRow + WMMA_M > M || cCol + WMMA_N > N) {
        return;
    }

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);

    // The WMMA API takes the leading dimension as `unsigned`; make the
    // narrowing from size_t explicit.
    const unsigned ldAB = static_cast<unsigned>(K);
    const unsigned ldC = static_cast<unsigned>(N);

    // Walk the reduction dimension one complete WMMA_K slice at a time.
    for (size_t k = 0; k + WMMA_K <= K; k += WMMA_K) {
        wmma::load_matrix_sync(a_frag, A + cRow * K + k, ldAB);
        wmma::load_matrix_sync(b_frag, B + cCol * K + k, ldAB);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }

    wmma::store_matrix_sync(C + cRow * N + cCol, c_frag, ldC, wmma::mem_row_major);
}

/**
 * Host-side launcher for wmmaNaiveKernel.
 *
 * Uses 128 x 4 thread blocks, i.e. 4 warps along x and 4 along y, so each
 * block covers a 64 x 64 region of C. The launch is issued on the default
 * stream; no synchronization or error checking is performed here.
 *
 * An empty output (M == 0 or N == 0) returns without launching: the grid
 * computation below subtracts 1 from M and N, which would wrap around for 0.
 */
void hgemmWmmaNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
    if (M == 0 || N == 0) {
        return;
    }

    dim3 block(128, 4);

    const size_t rowsPerBlock = WMMA_M * block.x / WARP_SIZE;  // rows of C covered by one block
    const size_t colsPerBlock = WMMA_N * block.y;              // columns of C covered by one block

    dim3 grid((M - 1) / rowsPerBlock + 1, (N - 1) / colsPerBlock + 1);

    wmmaNaiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}
2.3 区别
-
计算层级:CUDA Core是线程级别,Tensor Core是warp级别
-
计算维度:CUDA Core是一维逐点计算,Tensor Core是二维逐tile计算
-
计算依赖:WMMA调用Tensor Core需要借助数据存储类fragment,CUDA Core不需要借助其他数据存储类
3 底层代码
3.1 PTX
// PTX entry point for wmmaNaiveKernel(const half*, const half*, half*, size_t, size_t, size_t).
// param_0..param_2 are the three pointer arguments and param_3..param_5 the three
// 64-bit size arguments, in signature order (A, B, C, M, N, K in the kernel source).
.visible .entry _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm(
    .param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0,
    .param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1,
    .param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2,
    .param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3,
    .param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4,
    .param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5
)
{
    .reg .pred %p<8>;
    .reg .b16 %rs<2>;
    .reg .f32 %f<2>;
    .reg .b32 %r<58>;
    .reg .b64 %rd<28>;

    // Load kernel arguments: %rd9 = A, %rd10 = B, %rd11 = C, %rd14 = M, %rd12 = N, %rd13 = K.
    ld.param.u64 %rd9, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0];
    ld.param.u64 %rd10, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1];
    ld.param.u64 %rd11, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2];
    ld.param.u64 %rd14, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3];
    ld.param.u64 %rd12, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4];
    ld.param.u64 %rd13, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5];

    // %r22 = ctaid.x * ntid.x + tid.x (global x thread index).
    mov.u32 %r19, %ntid.x;
    mov.u32 %r20, %ctaid.x;
    mov.u32 %r21, %tid.x;
    mad.lo.s32 %r22, %r20, %r19, %r21;

    // %r26 = ctaid.y * ntid.y + tid.y (global y thread index, warpN in the source).
    mov.u32 %r23, %ntid.y;
    mov.u32 %r24, %ctaid.y;
    mov.u32 %r25, %tid.y;
    mad.lo.s32 %r26, %r24, %r23, %r25;

    // Build the zero accumulator: 0.0f converted to f16 and packed twice into one b32.
    mov.f32 %f1, 0f00000000;
    { cvt.rn.f16.f32 %rs1, %f1;}
    mov.b32 %r50, {%rs1, %rs1};

    // %rd1 = warpN * 16 (tile column origin).
    mul.wide.u32 %rd1, %r26, 16;

    // %rd2 = (x >> 1) & 0x7FFFFFF0, i.e. (x / 32) * 16 = warpM * 16 (tile row origin).
    shr.u32 %r27, %r22, 1;
    and.b32 %r28, %r27, 2147483632;
    cvt.u64.u32 %rd2, %r28;

    // %p1 = (tile row origin < M) && (tile column origin < N).
    setp.lt.u64 %p2, %rd2, %rd14;
    setp.lt.u64 %p3, %rd1, %rd12;
    and.pred %p1, %p2, %p3;

    // Skip the K loop entirely when K == 0; the accumulator registers
    // %r50..%r53 all hold packed zeros on either path.
    setp.eq.s64 %p4, %rd13, 0;
    mov.u32 %r51, %r50;
    mov.u32 %r52, %r50;
    mov.u32 %r53, %r50;
    @%p4 bra $L__BB0_5;

    // Loop-invariant values: %rd3 = rowOrigin * K, %rd4 = colOrigin * K,
    // %r2 = K truncated to 32 bits (leading dimension operand for the loads),
    // %rd5 / %rd6 = B / A converted to global-space addresses, %rd27 = loop counter i = 0.
    mul.lo.s64 %rd3, %rd2, %rd13;
    cvt.u32.u64 %r2, %rd13;
    mul.lo.s64 %rd4, %rd1, %rd13;
    cvta.to.global.u64 %rd5, %rd10;
    cvta.to.global.u64 %rd6, %rd9;
    mov.u64 %rd27, 0;
    not.pred %p5, %p1;
    mov.u32 %r51, %r50;
    mov.u32 %r52, %r50;
    mov.u32 %r53, %r50;

$L__BB0_2:
    // Per-iteration guard: only %p1 is tested here; no separate compare of i
    // against K appears inside the loop body.
    @%p5 bra $L__BB0_4;

    // Address of the A tile: A + 2 * (i + rowOrigin * K) bytes.
    add.s64 %rd16, %rd27, %rd3;
    shl.b64 %rd17, %rd16, 1;
    add.s64 %rd18, %rd6, %rd17;
    wmma.load.a.sync.aligned.row.m16n16k16.global.f16 {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, [%rd18], %r2;

    // Address of the B tile: B + 2 * (i + colOrigin * K) bytes.
    add.s64 %rd19, %rd27, %rd4;
    shl.b64 %rd20, %rd19, 1;
    add.s64 %rd21, %rd5, %rd20;
    wmma.load.b.sync.aligned.col.m16n16k16.global.f16 {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, [%rd21], %r2;

    // Accumulate in place: {%r53, %r52, %r51, %r50} is both destination and C operand.
    wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 {%r53, %r52, %r51, %r50}, {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, {%r53, %r52, %r51, %r50};

$L__BB0_4:
    // i += 16; continue while i < K.
    add.s64 %rd27, %rd27, 16;
    setp.lt.u64 %p6, %rd27, %rd13;
    @%p6 bra $L__BB0_2;

$L__BB0_5:
    // Store only when the tile origin is inside C (same predicate %p1).
    not.pred %p7, %p1;
    @%p7 bra $L__BB0_7;

    // Address of the C tile: C + 2 * (rowOrigin * N + colOrigin) bytes;
    // %r45 = N truncated to 32 bits (leading dimension operand for the store).
    mul.lo.s64 %rd22, %rd2, %rd12;
    add.s64 %rd23, %rd22, %rd1;
    cvta.to.global.u64 %rd24, %rd11;
    shl.b64 %rd25, %rd23, 1;
    add.s64 %rd26, %rd24, %rd25;
    cvt.u32.u64 %r45, %rd12;
    wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%rd26], {%r53, %r52, %r51, %r50}, %r45;

$L__BB0_7:
    ret;
}
不过我们主要关注WMMA相关的PTX指令,如下所示。可以看到,这里正是通过Nvidia提供的WMMA PTX指令来调用Tensor Core的,所以无论是使用WMMA API编程,还是直接使用WMMA PTX指令编程,底层差别都不会太大。(文章来源:https://www.toymoban.com/news/detail-411627.html)
wmma.load.a.sync.aligned.row.m16n16k16.global.f16 wmma.load.b.sync.aligned.col.m16n16k16.global.f16 wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 wmma.store.d.sync.aligned.row.m16n16k16.global.f16
3.2 SASS
IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] S2R R0, SR_CTAID.X ISETP.NE.U32.AND P2, PT, RZ, c[0x0][0x188], PT ULDC.64 UR4, c[0x0][0x118] CS2R R8, SRZ S2R R10, SR_CTAID.Y ISETP.NE.AND.EX P2, PT, RZ, c[0x0][0x18c], PT, P2 S2R R5, SR_TID.Y S2R R3, SR_TID.X IMAD R10, R10, c[0x0][0x4], R5 IMAD R0, R0, c[0x0][0x0], R3 IMAD.WIDE.U32 R10, R10, 0x10, RZ CS2R R2, SRZ SHF.R.U32.HI R0, RZ, 0x1, R0 ISETP.GE.U32.AND P0, PT, R10, c[0x0][0x180], PT LOP3.LUT R13, R0, 0x7ffffff0, RZ, 0xc0, !PT ISETP.GE.U32.AND.EX P0, PT, R11, c[0x0][0x184], PT, P0 ISETP.LT.U32.AND P1, PT, R13, c[0x0][0x178], PT ISETP.LT.U32.AND.EX P0, PT, RZ, c[0x0][0x17c], !P0, P1 @!P2 BRA 0x7f1eaefc0160 BSSY B0, 0x7f1eaefc0160 IMAD.MOV.U32 R0, RZ, RZ, RZ CS2R R8, SRZ IMAD.MOV.U32 R15, RZ, RZ, RZ IMAD.MOV.U32 R2, RZ, RZ, RZ BSSY B1, 0x7f1eaefc0100 @!P0 BRA 0x7f1eaefc00f0 S2R R16, SR_LANEID IMAD R17, R11, c[0x0][0x188], RZ IMAD.MOV.U32 R14, RZ, RZ, R0 IMAD.MOV.U32 R23, RZ, RZ, c[0x0][0x188] IMAD.WIDE.U32 R6, R10, c[0x0][0x188], R14 SHF.R.U32.HI R12, RZ, 0x1, R23 IMAD R17, R10, c[0x0][0x18c], R17 LEA R21, P2, R6, c[0x0][0x168], 0x1 IMAD.WIDE.U32 R4, R13, c[0x0][0x188], R14 IMAD.IADD R7, R7, 0x1, R17 IMAD.MOV.U32 R17, RZ, RZ, RZ IMAD R5, R13, c[0x0][0x18c], R5 LEA.HI.X R7, R6, c[0x0][0x16c], R7, 0x1, P2 SHF.R.U32.HI R19, RZ, 0x2, R16 LOP3.LUT R16, R16, 0x3, RZ, 0xc0, !PT IMAD.WIDE.U32 R16, R19, R12, R16 LEA R19, P1, R4, c[0x0][0x160], 0x1 LEA.HI.X R5, R4, c[0x0][0x164], R5, 0x1, P1 LEA R18, P1, R16, R19, 0x2 LEA R20, P2, R16, R21, 0x2 LEA.HI.X R19, R16, R5, R17, 0x2, P1 LEA.HI.X R21, R16, R7, R17, 0x2, P2 IMAD.WIDE.U32 R16, R23, 0x10, R18 LDG.E R4, [R18.64] IMAD.WIDE.U32 R22, R23, 0x10, R20 LDG.E R24, [R20.64] LDG.E R25, [R20.64+0x10] LDG.E R6, [R18.64+0x10] LDG.E R5, [R16.64] LDG.E R7, [R16.64+0x10] LDG.E R26, [R22.64] LDG.E R27, [R22.64+0x10] WARPSYNC 0xffffffff HMMA.16816.F16 R8, R4, R24, R8 HMMA.16816.F16 R2, R4, R26, R2 NOP BSYNC B1 IADD3 R0, P1, R0, 0x10, RZ IMAD.X R15, RZ, RZ, R15, P1 ISETP.GE.U32.AND P1, PT, R0, 
c[0x0][0x188], PT ISETP.GE.U32.AND.EX P1, PT, R15, c[0x0][0x18c], PT, P1 @!P1 BRA 0x7f1eaefbfe90 BSYNC B0 @!P0 EXIT S2R R4, SR_LANEID IMAD.MOV.U32 R15, RZ, RZ, c[0x0][0x180] WARPSYNC 0xffffffff IMAD.WIDE.U32 R10, R13, c[0x0][0x180], R10 SHF.R.U32.HI R15, RZ, 0x1, R15 IMAD.MOV.U32 R5, RZ, RZ, RZ LEA R7, P0, R10, c[0x0][0x170], 0x1 IMAD R11, R13, c[0x0][0x184], R11 LEA.HI.X R11, R10, c[0x0][0x174], R11, 0x1, P0 SHF.R.U32.HI R0, RZ, 0x2, R4 LOP3.LUT R4, R4, 0x3, RZ, 0xc0, !PT IMAD.WIDE.U32 R4, R0, R15, R4 LEA R6, P0, R4, R7, 0x2 LEA.HI.X R7, R4, R11, R5, 0x2, P0 IMAD.WIDE.U32 R4, R15, 0x20, R6 STG.E [R6.64], R8 STG.E [R4.64], R9 STG.E [R6.64+0x10], R2 STG.E [R4.64+0x10], R3 EXIT BRA 0x7f1eaefc02b0 NOP NOP NOP NOP NOP NOP NOP NOP NOP NOP NOP NOP
我们依然主要关注WMMA相关的SASS指令,如下所示。可以发现,m16n16k16的WMMA在底层是通过两条HMMA.16816指令实现的;同样地,SASS指令也是Nvidia提供的另一种调用Tensor Core的编程方法。(文章来源:https://www.toymoban.com/news/detail-411627.html)
HMMA.16816.F16
4 其他
4.1 HGEMM优化
到了这里,关于Nvidia Tensor Core-WMMA API编程入门的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!