普通的Attention

Tensor Parallel Attention

TP size = 4

Tensor Parallel, TP size = 4

Linear_QKVout_channel变成dim / tp_size, Linear_outin_channel变成dim / tp_sizeout_channel还是hidden_size,最后输出的shape是[seq_len, hidden_size], 因为Linear_out 切的是“累加维度”,所以需要All-Reduce加起来是完整的output

Sequence Parallel

Sequence Parallel

All-Reduce = Reduce-Scatter + All-Gather 。 通信量一样

seq length×hidden size4×3×2\frac{seq\ length \times hidden\ size}{4} \times 3 \times 2

Reduce-Scatter 后,每个卡都有部分sequce,但是已经累加好了,此时可以进行sequence并行计算add, layernorm等。算好后再All-Gather,使每个卡都持有完整的结果,进入下一layer的attention或者MLP

MLP的Tensor parallel 的输入输出和Attention一样,都是完整的输入,输出需要累加

MLP Tensor Parallel