  • Batch size $B$
  • Sequence length $S$ (written $s$ in the FLOP formulas)
  • Number of heads $H$
  • Head dimension $d$
  • Hidden size $h = H \times d$
  • Parallel size $p$

Tensor Parallel in one machine

Attention

Compute cost

Attention
Usually d = hidden_size
Each element of a matrix multiplication is computed with multiply-adds, and each multiply-add counts as 2 operations.
$XW_{qkv}$:

$$2 \times B \times 3 \times s \times h \times h = 6Bsh^2$$

$QK^T$:

$$2 \times B \times s \times h \times s = 2Bs^2h$$

$P = SV$:

$$2 \times B \times s \times s \times h = 2Bs^2h$$

$output = PW_o$:

$$2 \times B \times s \times h \times h = 2Bsh^2$$

Total:

$$\text{FLOP} = 8Bsh^2 + 4Bs^2h$$
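The four terms above can be checked numerically. The following is a minimal sketch, not a profiler: the function name `attention_prefill_flops` and the example shapes are assumptions chosen for illustration, and softmax, masking and bias terms are ignored, as in the formulas.

```python
def attention_prefill_flops(B: int, s: int, h: int) -> dict:
    """FLOPs of one attention layer over a full prompt (1 multiply-add = 2 ops)."""
    qkv = 2 * B * 3 * s * h * h    # X @ W_qkv: three h x h projections
    scores = 2 * B * s * h * s     # Q @ K^T
    weighted = 2 * B * s * s * h   # S @ V
    out = 2 * B * s * h * h        # P @ W_o
    return {
        "qkv": qkv,
        "scores": scores,
        "weighted": weighted,
        "out": out,
        "total": qkv + scores + weighted + out,
    }


# Hypothetical shapes: batch 1, 2048 tokens, hidden size 4096.
prefill = attention_prefill_flops(B=1, s=2048, h=4096)
print(prefill["total"])
```

With these shapes the two $h^2$ terms dominate; the $s^2$ terms only catch up once $s$ is on the order of $2h$.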

Communication volume:
All-Reduce after attention layer:

$$2 \times (p-1) \times \frac{1}{p} BSHd = \frac{2(p-1)BSHd}{p}$$

In the prefill stage, Sequence Parallel uses Reduce-Scatter + All-Gather, whose communication volume is the same as that of All-Reduce.
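The equivalence between All-Reduce and Reduce-Scatter + All-Gather can be made concrete with a small sketch. It assumes a ring-style algorithm in which each phase moves $(p-1)/p$ of the tensor per rank; the function names and shapes are hypothetical and no real communication library is called.

```python
def ring_phase_volume(n_elements: int, p: int) -> float:
    """Elements each rank sends in one ring Reduce-Scatter or one ring All-Gather."""
    return (p - 1) * n_elements / p


def all_reduce_volume(B: int, S: int, H: int, d: int, p: int) -> float:
    """All-Reduce modelled as Reduce-Scatter followed by All-Gather."""
    n = B * S * H * d
    return ring_phase_volume(n, p) + ring_phase_volume(n, p)


# Hypothetical shapes: B=1, S=1024, H=32, d=128, 4 ranks.
volume = all_reduce_volume(B=1, S=1024, H=32, d=128, p=4)
print(volume)  # elements per rank; multiply by sizeof(dtype) for bytes
```

As $p$ grows, the factor $2(p-1)/p$ approaches 2, so the per-rank volume is bounded by roughly twice the activation size.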

Sequence Parallel

MLP Block

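The communication before and after the block follows the same pattern as for Attention, so the communication volume is also the same.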

MLP Tensor Parallel
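Compute cost:
The MLP projects up and then back down; let the up-projected dimension be $d$.

$$2 \times B \times shd = 2Bshd$$

Total for the up-projection plus the down-projection:

$$\text{FLOP} = 4Bshd$$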
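A short sketch of the same count in code. The function name is hypothetical, and the choice of an up-projected dimension of $4h$ in the example is an assumption (a common setting), not something stated in these notes.

```python
def mlp_flops(B: int, s: int, h: int, d_up: int) -> int:
    """FLOPs of an up-projection (h -> d_up) followed by a down-projection (d_up -> h)."""
    up = 2 * B * s * h * d_up     # X @ W_up
    down = 2 * B * s * d_up * h   # hidden @ W_down
    return up + down


# Hypothetical shapes; d_up = 4 * h is an assumed expansion ratio.
h = 4096
mlp_total = mlp_flops(B=1, s=2048, h=h, d_up=4 * h)
print(mlp_total)
```

Under that assumed ratio the total becomes $16Bsh^2$, twice the $8Bsh^2$ spent on the attention projections.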

FFN in MOE

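One All-To-All before the FFN and one after.
The first gathers the Attention results so that the FFN receives the complete sequence as input; the second distributes the FFN output back to the ranks, which continue TP execution of the other layers.

$$2 \times (p-1) \times \frac{1}{p} BSHd = \frac{2(p-1)BSHd}{p}$$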

Sequence Parallel

Compute cost:
The FFN inside MOE projects down and then back up; let the down-projected dimension be $d$.

$$\text{FLOP} = 4Bshd$$

Pipeline Parallel

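Layers are partitioned across ranks; each rank is responsible for one part of the network.
A single transfer between ranks is $BSHd$, so the total communication is

$$(p-1) \times BSHd$$

Communication volume is reduced and server utilization is maximized.
Each request runs on the ranks in turn, passing its output to the next rank, so latency is higher than with Tensor Parallel.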
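To see how much communication Pipeline Parallel saves, the two volume formulas can be compared side by side. This is a rough sketch under stated assumptions: Tensor Parallel performs one All-Reduce after attention and one after the MLP in every layer, and the layer count and shapes are hypothetical.

```python
def tp_all_reduce_volume(n_elements: int, p: int) -> float:
    """Per-rank volume of one All-Reduce: 2 * (p - 1) / p * n."""
    return 2 * (p - 1) * n_elements / p


def pp_total_volume(n_elements: int, p: int) -> int:
    """Pipeline Parallel: one activation transfer across each of the p - 1 stage boundaries."""
    return (p - 1) * n_elements


# Hypothetical configuration.
B, S, H, d = 1, 1024, 32, 128
n = B * S * H * d
p = 4
num_layers = 32  # assumed; two All-Reduces per layer under TP

tp_total = num_layers * 2 * tp_all_reduce_volume(n, p)
pp_total = pp_total_volume(n, p)
print(tp_total, pp_total)
```

The Tensor Parallel total grows with the number of layers, while the Pipeline Parallel total depends only on the number of stage boundaries.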

TP + PP: for example, with TP size $= 4$ and PP size $= 2$, the result of the All-Reduce after attention can be passed from $P_0$ to $P_1$; the communication volume of one such transfer is $p \times \frac{BSHd}{p} = BSHd$.

KV Cache

  • Number of layers $L$

The decoder needs the K and V of the tokens from every previous step, so the cached length is $s' = s_{prefill} + s_{decode}$.

KV cache size:

$$2 \times L \times Hd \times s'$$

Multiply by sizeof(dtype) to get the size in bytes.
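As a worked example of the size formula, here is a minimal sketch. The function name, the `batch` multiplier (the formula above is per sequence) and the model shapes are assumptions for illustration only.

```python
def kv_cache_bytes(L: int, H: int, d: int, s_total: int,
                   dtype_bytes: int = 2, batch: int = 1) -> int:
    """KV cache size in bytes: 2 (K and V) * L * H * d * s', times sizeof(dtype).

    `batch` is an added convenience multiplier; the formula in the text is per sequence.
    """
    return 2 * L * H * d * s_total * dtype_bytes * batch


# Hypothetical model: 32 layers, 32 heads of dim 128, 4096 cached tokens, fp16 (2 bytes).
cache = kv_cache_bytes(L=32, H=32, d=128, s_total=4096, dtype_bytes=2)
print(cache / 2**30, "GiB")
```

The size is linear in $s'$, so doubling the context length doubles the cache.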

Compute cost:
Each step only needs to compute the QKV of the current token, i.e. $s = 1$ for the projections.
$XW_{qkv}$:

$$\text{FLOP} = 2 \times B \times 3 \times 1 \times h \times h = 6Bh^2$$

$Q_iK^T$ multiplies the current token's $Q_i$ with all of $K$ (here $s$ is the cached length):

$$\text{FLOP} = 2 \times B \times 1 \times h \times s = 2Bsh$$

$P = SV$:

$$\text{FLOP} = 2 \times B \times 1 \times s \times h = 2Bsh$$

$output = PW_o$:

$$\text{FLOP} = 2 \times B \times 1 \times h \times h = 2Bh^2$$

Total compute:

$$\text{FLOP} = 8Bh^2 + 4Bsh$$
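The per-step cost with a KV cache can be sketched in the same way as the prefill cost. The function name and shapes are hypothetical; `s` here is the number of cached tokens the new query attends to.

```python
def attention_decode_flops(B: int, s: int, h: int) -> int:
    """FLOPs of one attention layer for a single decode step with a KV cache."""
    qkv = 2 * B * 3 * 1 * h * h   # project only the current token
    scores = 2 * B * 1 * h * s    # Q_i against all cached K
    weighted = 2 * B * 1 * s * h  # attention weights against all cached V
    out = 2 * B * 1 * h * h       # output projection
    return qkv + scores + weighted + out


# Hypothetical shapes: batch 1, 2048 cached tokens, hidden size 4096.
step_flops = attention_decode_flops(B=1, s=2048, h=4096)
print(step_flops)
```

Compared with recomputing attention over the whole sequence, each term drops by a factor of $s$, which is the saving the cache buys at the cost of the memory computed earlier.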