Online Softmax

Softmax公式

Softmax(X)=eXimax(X)j=0neXjmax(X)Softmax(X) = \frac{e^{X_i - max(X)}}{\displaystyle\sum_{j=0}^{n} e^{X_j-max(X)}}

指数上都减掉max(X)max(X)防止溢出

简单的代码就是

1
2
3
4
5
6
7
8
9
10
11
12
vector<float> arr;
float sum = 0;
float max_v = arr[0];
for(size_t i = 0 ;i < arr.size(); i++){
max_v = max(max_v, arr[i]);
}
for(size_t i = 0 ;i < arr.size(); i++){
sum += exp(arr[i] - max_v);
}
for(size_t i = 0 ;i < arr.size(); i++){
arr[i] = exp(arr[i] - max_v) / sum ;
}

这里求max 和求sum的时候需要两次reduce,这是不可接受的。

如何压榨这个代码?

sum 是必须的, max 也是必须的,但是这两个for做的事情可以同时做到,但是当前的max_v不一定就是全局的最大值,需要不断的更新更大的max_v,同时要把前面少减的补回来。

Ximax(X)X_i - max(X) 这个操作是在指数上的,也就是说这里是 eXi÷emax(X)e^{X_i} \div e^{max(X)}

gf=eg=e+fa÷g=a÷e÷f\begin{align} g-f &= &e \nonumber \\ g &= &e + &f \nonumber \\ a \div g &= a \div &e \div &f \nonumber \end{align}

所以在for循环XRN,i=nX \in \R^N, i=n时,当新值需要更新max_v值时,直接按新的max算,并把前面已经加好的值除去新老max的差值

if n+1maxnn+1 \neq max_n

i=0n+1eXi=i=0neXi÷e(Xn+1maxn)+eXn+1max(Xn+1,maxn)=i=0neXi×e(maxnXn+1)+1\begin{align} \sum_{i=0}^{n+1} e^{X_i} &= \sum_{i=0}^n e^{X_i}\div e^{(X_{n+1} - max_n)} + e^{X_{n+1} - max(X_{n+1},max_n)} \nonumber \\ &= \sum_{i=0}^n e^{X_i} \times e^{(max_n - X_{n+1})} + 1 \end{align}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
vector<float> arr;
float sum = 0;
float max_v = arr[0];
for(size_t i = 0 ;i < arr.size(); i++){
if(arr[i] > max_v){
sum *= exp(max_v - arr[i]);
sum += 1;
max_v = arr[i];
}
else{
sum += exp(arr[i] - max_v);
}
}

for(size_t i = 0 ;i < arr.size(); i++){
arr[i] = exp(arr[i] - max_v) / sum ;
}

这样,可以一边求max一边算sum,而且元素直接互不干扰,只要把少的max_v补上就行了,在GPU上分块求和也不影响,block内算自己的,最后block间把差的max_v补齐就是了。

Flash Attention

切开算,基于上文online softmax 的方法尽可能的分块算,然后在最后对齐结果。

Attention 计算公式:

XX是输入,NN是sequence length,dd是hidden size

Q=XWQRN×dK=XWKRN×dV=XWVRN×d\begin{align} Q &= XW_Q \in \R^{N \times d} \\ K &= XW_K \in \R^{N \times d} \\ V &= XW_V \in \R^{N \times d} \\ \end{align}

从这开始:Attention输入为Q,K,VQ,K,V三矩阵,输出OO矩阵

S=Softmax(QKT)dRN×NO=SVRN×d\begin{align} S &= \frac{Softmax(QK^T)}{\sqrt{d}} & \in \R^{N\times N} \\ O &= SV & \in \R^{N \times d} \end{align}

算法过程

Q,K,VQ,K,V都切成B=M4dB = \lceil \frac{M}{4d} \rceil 大小,MM是SRAM的size,比如8K。4d4d是qkvo的空间,这样切目的是把sram用满。
好了,QKVO都切成了T=4NdMT=\lceil \frac{4Nd}{M} \rceil 个块。(假设都刚刚好整除)
B是块大小,T是块个数

两层循环:外层for K,V,用j,内层for Q用i

for j in Tfor i in TSij=QiKjTRB×B\text{for } j \text{ in } T \\ \text{for } i \text{ in } T \\ S_{ij} = Q_iK_j^T \in \R^{B\times B} \\

用前面softmax中求分母的方法得到sumlijl_{ij}和最大值mijm_{ij}

ij是新的

Pij=eSijmijRB×Bmij=rowmax(Sij)RBlij=sum(Pij)RB\begin{align} P_{ij} &= e^{S_{ij} - m_{ij}} & \in \R^{B\times B } \nonumber \\ m_{ij} &= \text{rowmax}(S_{ij}) & \in \R^B \nonumber \\ l_{ij} &= \text{sum}(P_{ij}) & \in \R^B \nonumber \\ \end{align}

mij ,lijm_{ij}\ ,l_{ij}更新到当前的mi ,lim\rq_i \ ,l\rq_i

i保留着历史值,每次outerloop j把所有i过一遍留着下一轮用

mi=max(mi,mij)RBli=emimili+emijmilijRB×B\begin{align} m\rq_{i} &= \text{max}(m_i,m_ij) & \in \R^B \nonumber \\ l\rq_i &= e^{m_i-m\rq_i}l_i + e^{m_ij-m\rq_i}l_ij & \in \R^{B\times B} \nonumber \\ \end{align}

计算最后的O=SVO = SVOO在z维度累加,因为sum和max在变,所以这里也要在线更新历史结果,补齐max导致的差异

Oi=diag(li)1(diag(li)emimiOi+emijmiPijVj)\begin{align} O_i = \text{diag}(l\rq_i)^{-1}(\text{diag}(l_i)e^{m_i-m\rq_i}O_i + e^{m_{ij} - m\rq_i} P_{ij}V{j} ) \nonumber \end{align}

最后把Oi li miO_i \ l\rq_i \ m\rq_i写回global mem,继续下一个loop i