Multi-head Latent Attention
普通Multi-head attention
MLA(Multi-head Latent Attention) 省略掉scale mask等。
可以看到QKV前都做了一次降维
具体训练网络:
因为K不能直接加ROPE,所以这里额外加一部分不压缩的,可以加ROPE的维度,拼成一个完整的QK。Q只是降维再升维,马甲脱了再穿还是它。
为什么不能加ROPE看下面推理过程。
推理过程:
可以看到,训练中K和V的权重位置跑掉了,原本的K V cache
变成Compressed Tensor
。
Q也变了,降维的Q后面乘了Q的权重和K的权重。QK的K变成了Compressed Tensor
,SV的V也是Compressed Tensor
,而且最后output还乘了个V的权重
表示降维后的K(也当V),是把降维的K升维的权重。
推理过程中把 存入cache, 合并计算 。这样和QK是等效的,同理, 也变成了
然而直接把C当成K计算,C上没法加ROPE了,所以训练就单独加一点普通的QK用来加ROPE。
这样推理时Cache的东西就只有Compressed Tensor
,相比KV cache很少,但是计算的就多了。decode阶段是访存瓶颈,所以依然有很大收益
Reference
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 JMY Space!