普通Multi-head attention

attention

MLA(Multi-head Latent Attention) 省略掉scale mask等。

可以看到QKV前都做了一次降维
MLA

具体训练网络:

因为K不能直接加ROPE,所以这里额外加一部分不压缩的,可以加ROPE的维度,拼成一个完整的QK。Q只是降维再升维,马甲脱了再穿还是它。
为什么不能加ROPE看下面推理过程。
mla training

推理过程:

mla inference
可以看到,训练中K和V的权重位置跑掉了,原本的K V cache 变成Compressed Tensor
Q也变了,降维的Q后面乘了Q的权重和K的权重。QK的K变成了Compressed Tensor,SV的V也是Compressed Tensor,而且最后output还乘了个V的权重

Q=XWQK=CWKV=CWVQKT=(XWQ)(CKWK)T=X(WQWKT)CQ = XW_Q \\ K = CW_K \\ V = CW_V \\ QK^T = (XW_Q)(C_KW_K)^T = X(W_QW_K^T)C \\

CC表示降维后的K(也当V),WKW_K是把降维的K升维的权重。
推理过程中把CC 存入cache, 合并计算 X(WQWKT)X(W_QW_K^T) 。这样和QK是等效的,同理,O=SVO=SV 也变成了 O=S(CWV)=SCWVO = S(CW_V) = SCW_V

然而直接把C当成K计算,C上没法加ROPE了,所以训练就单独加一点普通的QK用来加ROPE。

这样推理时Cache的东西就只有Compressed Tensor,相比KV cache很少,但是计算的就多了。decode阶段是访存瓶颈,所以依然有很大收益

Reference