DeepSeek is open-sourcing something new every day; I can't keep up.

It's an FP8 GEMM library; kernels are JIT-compiled at runtime, which makes it very convenient to use.

It exclusively supports NVIDIA Hopper tensor cores, so my 40-series GPU can't run it.

It takes inspiration from CUTLASS but is decoupled from CUTLASS's heavyweight utils; the core kernel is only ~300 lines of code.

They even compared the compiled CUTLASS SASS and found a hidden trick:

We observe a performance improvement in the CUTLASS FP8 kernel between NVCC 12.2 and 12.3. By comparing the compiled SASS, we discover that one bit in a series of FADD instructions is flipped in an interleaving pattern. After referencing some open-source CUDA assembler implementations, we identified that this bit controls yield, which may enhance warp-level parallelism (just a guess, yielding the current warp and let other warps work).

Data loading is accelerated with TMA and overlapped with the computation, via the full/empty barrier pattern below.

```cpp
// Set up the barriers: "full" means a stage holds data ready to be computed on,
// "empty" means the stage has been consumed and can be refilled with the next one.
Barrier *full_barriers[kNumStages], *empty_barriers[kNumStages];

if (threadIdx.x >= kNumMathThreads) {
    // These threads are producers: they copy data.
    empty_barriers[i]->wait();   // wait until stage i has been consumed
    // ... copy ...
    full_barriers[i]->arrive();  // stage i now holds fresh data
} else {
    // These threads are consumers: they do the math.
    full_barriers[i]->wait();    // wait until stage i has data
    // ... calc ...
    empty_barriers[i]->arrive(); // stage i can be reused
}
```

The matrix math uses WGMMA, the Hopper tensor-core warp-group MMA instruction (matrix A×B + C). The tensor operands here are described with CuTe layouts, a hierarchical tensor layout abstraction for accelerating matrix computation; this paper covers the history of such layouts.
This kind of 2D layout makes it easy to describe cutting tiles out of a large matrix.
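
To make the shape/stride idea concrete, here is a minimal standalone model (plain C++, not CuTe itself; `Layout2D` and all names are made up for illustration): a layout maps a coordinate to a linear offset, so a tile of a big matrix is just the same strides with a smaller shape plus a base offset.

```cpp
#include <cstdio>

// Simplified 2D shape/stride layout: offset = i * stride_row + j * stride_col.
struct Layout2D {
    int shape_row, shape_col;    // extent in each mode
    int stride_row, stride_col;  // memory step per unit coordinate
    int operator()(int i, int j) const { return i * stride_row + j * stride_col; }
};

int main() {
    // A 128x128 row-major matrix.
    Layout2D big{128, 128, 128, 1};
    // The 64x64 tile starting at element (64, 64) keeps the same strides;
    // only the shape shrinks and a base offset is added.
    int tile_base = big(64, 64);
    Layout2D tile{64, 64, 128, 1};
    // Element (2, 3) of that tile, addressed in the big matrix:
    std::printf("%d\n", tile_base + tile(2, 3));  // 64*128 + 64 + 2*128 + 3
    return 0;
}
```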

Each loop body computes at most 4 tiles: $A_iB_j$, $A_{i+1}B_j$, $A_iB_{j+1}$, $A_{i+1}B_{j+1}$.
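
The benefit of touching a 2×2 block of output tiles in one iteration is operand reuse: two A fragments and two B fragments feed four accumulations, so each loaded fragment is used twice. A trivial sketch of just the access pattern (indices are hypothetical, not DeepGEMM's actual loop structure):

```cpp
#include <cstdio>

int main() {
    int i = 0, j = 0;  // top-left output tile handled by this loop body
    // Four output tiles per iteration; each A[i+di] and B[j+dj] fragment
    // appears in two of the four multiply-accumulates.
    for (int di = 0; di < 2; ++di)
        for (int dj = 0; dj < 2; ++dj)
            std::printf("C[%d][%d] += A[%d] * B[%d]\n", i + di, j + dj, i + di, j + dj);
    return 0;
}
```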

```cpp
// Inner loop over K within one tile: issue one WGMMA per K-slice,
// accumulating into the per-thread accumulator registers.
for (int k = 0; k < BLOCK_K / WGMMA::K; ++k) {
    auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
    auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
    WGMMA::wgmma(desc_a, desc_b, accum, k);
}
// do something
// Apply the scale factors and accumulate the partial results into final_accum.
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) {
    bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
    final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
    final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
    final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
    final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
```
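
My reading of that second loop (this matches the fine-grained FP8 scheme, where A and B are quantized per block with their own scale factors): the FP8 products for one K-block are accumulated first, then promoted into the FP32 result with the corresponding dequantization scales, roughly

$$C_{ij} \;\approx\; \sum_{\text{K-blocks}} s_A\, s_B \sum_{k \in \text{block}} \hat A_{ik}\, \hat B_{kj},$$

where $\hat A$, $\hat B$ are the FP8-quantized values and $s_A$, $s_B$ are their per-block scales; `accum` holds the inner sum and `final_accum` the outer one.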

The WGMMA implementation is PTX lifted straight from CUTLASS, with wrappers for every tile size; see mma_utils.cuh.
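
Judging from how the kernel above calls it, each size is wrapped as a type exposing its dimensions, its accumulator count, and a wgmma() that issues the PTX. A hypothetical outline (names and signature are my guess from the call sites, not the actual mma_utils.cuh):

```cpp
#include <cstdint>

// Hypothetical wrapper for one WGMMA tile size (m64n64k32, FP8 inputs,
// FP32 accumulators). The real wrappers live in mma_utils.cuh.
struct WGMMA_64x64x32 {
    static constexpr int M = 64, N = 64, K = 32;
    // FP32 accumulator registers held per thread (a warpgroup has 128 threads).
    static constexpr int kNumAccum = M * N / 128;

    __device__ static void wgmma(uint64_t desc_a, uint64_t desc_b,
                                 float* accum, bool scale_d) {
        // Would issue the corresponding inline-PTX wgmma.mma_async instruction,
        // using scale_d to choose between overwriting and accumulating.
    }
};
```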