Multi-head Latent Attention
普通Multi-head attention
MLA(Multi-head Latent Attention) 省略掉scale mask等。
可以看到QKV前都做了一次降维
具体训练网络:
因为K不能直接加ROPE,所以这里额外加一部分不压缩的,可以加ROPE的维度,拼成一个完整的QK。Q只是降维再升维,马甲脱了再穿还是它。
为什么不能加ROPE看下面推理过程。(训练网络拆解示意图,非最终代码实现)
推理过程优化:
在代码中compressed_kv
和未加rope的k k_pe
在一个Linear
合并计算然后split。
1 | compressed_kv = self.kv_a_proj_with_mqa(hidden_states) |
K和V的权重位置跑掉了,原本的K V cache
变成compressed latent vector
。Q也变了,降维的Q后面乘了Q的权重和K的权重。QK的K变成了compressed latent vector
,SV的V也是compressed latent vector
,而且最后output还乘了个V的权重
这里 表示降维后的KV,论文中的,是把降维的K升维的权重。推理过程中把 存入cache, 合并计算 。这样和QK是等效的,同理, 也变成了 ,然而直接把C当成K计算,C上没法加ROPE了,所以训练就单独加一点普通的QK用来加ROPE。
这样推理时Cache的东西就只有compressed latent vector
和rope(K),相比KV cache很少,但是计算的就多了。decode阶段是访存瓶颈,所以依然有很大收益
1 | kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) |
计算QK时(代码里是attn_weights
),先计算 q_nope
后乘 torch.matmul(q_nope, q_absorb)
, 而不是离线算好 ,因为分开计算量更小
Reference
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 JMY Space!