普通Multi-head attention

attention

MLA(Multi-head Latent Attention) 省略掉scale mask等。

可以看到QKV前都做了一次降维
MLA

具体训练网络:

因为K不能直接加ROPE,所以这里额外加一部分不压缩的,可以加ROPE的维度,拼成一个完整的QK。Q只是降维再升维,马甲脱了再穿还是它。
为什么不能加ROPE看下面推理过程。(训练网络拆解示意图,非最终代码实现)
mla training

推理过程优化:

mla inference
在代码中compressed_kv和未加rope的k k_pe 在一个Linear合并计算然后split。

1
2
3
4
5
6
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_layernorm(compressed_kv)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)

K和V的权重位置跑掉了,原本的K V cache 变成compressed latent vector。Q也变了,降维的Q后面乘了Q的权重和K的权重。QK的K变成了compressed latent vector,SV的V也是compressed latent vector,而且最后output还乘了个V的权重

Q=XWQK=CTWKV=CTWVQKT=(XWQ)(CTWK)T=(XWQWKT)CTQ = XW_Q \\ K = C_TW_K \\ V = C_TW_V \\ QK^T = (XW_Q)(C_TW_K)^T = (XW_QW_K^T)C_T \\

这里CC 表示降维后的KV,论文中的CtC_tWKW_K是把降维的K升维的权重。推理过程中把 CC 存入cache, 合并计算 X(WQWKT)X(W_QW_K^T) 。这样和QK是等效的,同理,O=SVO=SV 也变成了 O=S(CWV)=SCWVO = S(CW_V) = SCW_V ,然而直接把C当成K计算,C上没法加ROPE了,所以训练就单独加一点普通的QK用来加ROPE。
这样推理时Cache的东西就只有compressed latent vector和rope(K),相比KV cache很少,但是计算的就多了。decode阶段是访存瓶颈,所以依然有很大收益

1
2
3
4
5
6
7
8
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :] # Linear K
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :] # Linear V

q_nope = torch.matmul(q_nope, q_absorb) # 吸收 Linear K
attn_weights = (torch.matmul(q_pe, k_pe.mT) +
torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
# 没有concat,直接相加,数学等价,减少显存

计算QK时(代码里是attn_weights),先计算XWQXW_Q q_nope 后乘 WKTW_K^T torch.matmul(q_nope, q_absorb), 而不是离线算好 WQWKTW_QW_K^T,因为分开计算量更小

Reference