DeepSeek is open-sourcing one project a day; there is no way to keep up with studying them all.
This one is an FP8 GEMM library (DeepGEMM). You use it through JIT compilation, which is very convenient.
It exclusively supports NVIDIA Hopper tensor cores, so my 40-series card can't run it.
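The JIT design means the kernels are compiled at runtime for the exact problem shapes, so block sizes, stage counts, and so on become compile-time constants. The sketch below is not DeepGEMM's JIT code; it is just a minimal NVRTC illustration of the runtime-compile-and-launch idea, with a toy `scale` kernel standing in for a GEMM and error handling stripped. Link it with `-lnvrtc -lcuda`.

```cpp
#include <cuda.h>
#include <nvrtc.h>
#include <cstdio>
#include <string>
#include <vector>

int main() {
    // Kernel source assembled at runtime; a JIT GEMM would bake tile shapes in here.
    const char* src = R"(
extern "C" __global__ void scale(float* x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
})";
    nvrtcProgram prog;
    nvrtcCreateProgram(&prog, src, "scale.cu", 0, nullptr, nullptr);
    const char* opts[] = {"--gpu-architecture=compute_90"};  // Hopper; use compute_89 for Ada
    if (nvrtcCompileProgram(prog, 1, opts) != NVRTC_SUCCESS) {
        size_t logSize; nvrtcGetProgramLogSize(prog, &logSize);
        std::string log(logSize, '\0'); nvrtcGetProgramLog(prog, &log[0]);
        printf("%s\n", log.c_str());
        return 1;
    }
    size_t ptxSize; nvrtcGetPTXSize(prog, &ptxSize);
    std::string ptx(ptxSize, '\0'); nvrtcGetPTX(prog, &ptx[0]);
    nvrtcDestroyProgram(&prog);

    // Load the freshly generated PTX with the driver API and launch it.
    cuInit(0);
    CUdevice dev; cuDeviceGet(&dev, 0);
    CUcontext ctx; cuCtxCreate(&ctx, 0, dev);
    CUmodule mod; cuModuleLoadData(&mod, ptx.c_str());
    CUfunction fn; cuModuleGetFunction(&fn, mod, "scale");

    const int n = 1024;
    CUdeviceptr d_x; cuMemAlloc(&d_x, n * sizeof(float));
    std::vector<float> h(n, 1.0f);
    cuMemcpyHtoD(d_x, h.data(), n * sizeof(float));
    float s = 2.0f; int nn = n;
    void* args[] = {&d_x, &s, &nn};
    cuLaunchKernel(fn, (n + 255) / 256, 1, 1, 256, 1, 1, 0, nullptr, args, nullptr);
    cuCtxSynchronize();
    cuMemcpyDtoH(h.data(), d_x, n * sizeof(float));
    printf("x[0] = %f\n", h[0]);  // expect 2.0
    cuMemFree(d_x); cuModuleUnload(mod); cuCtxDestroy(ctx);
    return 0;
}
```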
It takes CUTLASS as a reference but is decoupled from CUTLASS's heavyweight utils; the core kernel is only ~300 lines of code.
They even diffed the compiled CUTLASS SASS and discovered an undocumented trick:
We observe a performance improvement in the CUTLASS FP8 kernel between NVCC 12.2 and 12.3. By comparing the compiled SASS, we discover that one bit in a series of FADD instructions is flipped in an interleaving pattern. After referencing some open-source CUDA assembler implementations, we identified that this bit controls yield, which may enhance warp-level parallelism (just a guess, yielding the current warp and let other warps work).
Data movement is accelerated with TMA and overlapped with the computation:
```cpp
Barrier *full_barriers[kNumStages], *empty_barriers[kNumStages];

if (threadIdx.x >= kNumMathThreads) {
    // Producer (TMA) warps: wait until stage i is empty before refilling it, then signal it full.
    empty_barriers[i]->wait();
    full_barriers[i]->arrive();
} else {
    // Consumer (math) warps: wait until stage i is full before computing on it, then signal it empty.
    full_barriers[i]->wait();
    empty_barriers[i]->arrive();
}
```
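The real kernel drives this handshake with Hopper TMA, mbarriers, and warp specialization (TMA warps vs. math warps). As a self-contained toy that runs on any recent GPU, the sketch below (my code, not DeepGEMM's) shows the same copy/compute overlap with two shared-memory stages, using `cuda::memcpy_async` and `cuda::barrier` in place of TMA, and a block-wide sync standing in for the empty barriers. Compile with e.g. `nvcc -arch=sm_80 overlap.cu`.

```cpp
#include <cuda/barrier>
#include <cooperative_groups.h>
#include <cstdio>

namespace cg = cooperative_groups;

// Toy double-buffered copy/compute overlap (assumes n_chunks >= 1).
// Launch with 2 * chunk * sizeof(float) bytes of dynamic shared memory.
__global__ void overlapped_sum(const float* in, float* out, int n_chunks, int chunk) {
    extern __shared__ float smem[];                         // two stages of `chunk` floats
    __shared__ cuda::barrier<cuda::thread_scope_block> bar[2];
    auto block = cg::this_thread_block();
    if (block.thread_rank() == 0) {
        init(&bar[0], block.size());
        init(&bar[1], block.size());
    }
    block.sync();

    float acc = 0.f;
    // Prefetch chunk 0 into stage 0; bar[0] plays the role of its "full" barrier.
    cuda::memcpy_async(block, smem, in, sizeof(float) * chunk, bar[0]);
    for (int c = 0; c < n_chunks; ++c) {
        int cur = c & 1, nxt = cur ^ 1;
        // Make sure nobody is still reading the stage we are about to overwrite
        // (DeepGEMM uses the "empty" barriers for this instead of a full sync).
        block.sync();
        if (c + 1 < n_chunks)
            cuda::memcpy_async(block, smem + nxt * chunk, in + (c + 1) * chunk,
                               sizeof(float) * chunk, bar[nxt]);
        bar[cur].arrive_and_wait();                         // wait until stage `cur` has landed
        for (int i = block.thread_rank(); i < chunk; i += block.size())
            acc += smem[cur * chunk + i];                   // the "math" on the current stage
    }
    atomicAdd(out, acc);
}

int main() {
    const int chunk = 256, n_chunks = 8, n = chunk * n_chunks;
    float *in, *out;
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) in[i] = 1.0f;
    *out = 0.0f;
    overlapped_sum<<<1, 128, 2 * chunk * sizeof(float)>>>(in, out, n_chunks, chunk);
    cudaDeviceSynchronize();
    printf("sum = %f (expected %d)\n", *out, n);
    cudaFree(in); cudaFree(out);
    return 0;
}
```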
The matrix computation uses WGMMA, the Hopper tensor-core warp-group MMA instruction (the matrix analogue of FMA: A*B + C). Describing the operand tensors involves the CuTe layout, a hierarchical tensor layout abstraction for accelerating matrix kernels. On layouts: this paper covers the history of layout abstractions.
This kind of 2D layout makes it convenient to describe cutting tiles out of a large matrix.
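A CuTe layout is essentially nested (shape, stride) pairs, and an element's linear offset is its hierarchical coordinate dotted with the strides. The code below is not CuTe itself, just a minimal stand-in I wrote to show how the same two-level (within-tile, which-tile) description addresses both a plain row-major matrix and a tile-contiguous one.

```cpp
#include <cstdio>

// Minimal stand-in for a CuTe-style hierarchical layout: each mode is a
// (shape, stride) pair, and a flat row/col coordinate is split into
// (within-tile, which-tile) parts before the dot product with the strides.
struct Mode { int shape, stride; };

int offset1d(Mode inner, Mode outer, int coord) {
    return (coord % inner.shape) * inner.stride + (coord / inner.shape) * outer.stride;
}

int main() {
    const int N = 8, TILE = 4;          // 8x8 matrix cut into 4x4 tiles
    int r = 5, c = 6;

    // Plain row-major storage, described hierarchically.
    Mode row_in{TILE, N},  row_out{N / TILE, TILE * N};
    Mode col_in{TILE, 1},  col_out{N / TILE, TILE};
    int off_rm = offset1d(row_in, row_out, r) + offset1d(col_in, col_out, c);

    // Same machinery, different strides: each 4x4 tile stored contiguously,
    // tiles themselves laid out row-major (the sort of layout a kernel wants in smem).
    Mode row_in2{TILE, TILE}, row_out2{N / TILE, TILE * N};
    Mode col_in2{TILE, 1},    col_out2{N / TILE, TILE * TILE};
    int off_tiled = offset1d(row_in2, row_out2, r) + offset1d(col_in2, col_out2, c);

    printf("row-major offset = %d (expected %d), tiled offset = %d\n",
           off_rm, r * N + c, off_tiled);   // 46, 46, 54
    return 0;
}
```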
A single loop body computes at most 4 tiles: $A_iB_j$, $A_{i+1}B_j$, $A_iB_{j+1}$, $A_{i+1}B_{j+1}$.
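Presumably the point is operand reuse: the two A tiles and two B tiles are each loaded once and feed four accumulators. Below is a toy scalar version of that 2x2 blocking (my sketch, with each "tile" shrunk to a single float).

```cpp
#include <cstdio>

int main() {
    // 2x4 * 4x2 toy "GEMM" where each element stands in for a whole tile:
    // per k-step, A_i, A_{i+1}, B_j, B_{j+1} are each read once and update
    // all four accumulators C[i..i+1][j..j+1].
    const int K = 4;
    float A[2][K] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
    float B[K][2] = {{1, 0}, {0, 1}, {1, 0}, {0, 1}};
    float C[2][2] = {};
    for (int k = 0; k < K; ++k) {
        float a0 = A[0][k], a1 = A[1][k];   // A_i, A_{i+1}
        float b0 = B[k][0], b1 = B[k][1];   // B_j, B_{j+1}
        C[0][0] += a0 * b0;                 // A_i     * B_j
        C[0][1] += a0 * b1;                 // A_i     * B_{j+1}
        C[1][0] += a1 * b0;                 // A_{i+1} * B_j
        C[1][1] += a1 * b1;                 // A_{i+1} * B_{j+1}
    }
    printf("%g %g / %g %g\n", C[0][0], C[0][1], C[1][0], C[1][1]);  // 4 6 / 12 14
    return 0;
}
```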
```cpp
for (int k = 0; k < BLOCK_K / WGMMA::K; ++k) {
    // Build shared-memory descriptors for the current k-slice of A and B, then issue the WGMMA.
    auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
    auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
    WGMMA::wgmma(desc_a, desc_b, accum, k);
}
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) {
    // Promotion: fold the WGMMA partial results into the persistent FP32 accumulators,
    // multiplied by the dequantization scales; the predicate picks which B-block scale
    // applies to this group of four registers.
    bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
    final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
    final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
    final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
    final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
```
The WGMMA wrappers are adapted from CUTLASS, with variants for the different instruction shapes; they live in mma_utils.cuh.