Online Softmax
Softmax公式
Softmax(X)=j=0∑neXj−max(X)eXi−max(X)
指数上都减掉max(X)防止溢出
简单的代码就是
1 2 3 4 5 6 7 8 9 10 11 12
| vector<float> arr; float sum = 0; float max_v = arr[0]; for(size_t i = 0 ;i < arr.size(); i++){ max_v = max(max_v, arr[i]); } for(size_t i = 0 ;i < arr.size(); i++){ sum += exp(arr[i] - max_v); } for(size_t i = 0 ;i < arr.size(); i++){ arr[i] = exp(arr[i] - max_v) / sum ; }
|
这里求max 和求sum的时候需要两次reduce,这是不可接受的。
如何压榨这个代码?
sum 是必须的, max 也是必须的,但是这两个for做的事情可以同时做到,但是当前的max_v
不一定就是全局的最大值,需要不断的更新更大的max_v
,同时要把前面少减的补回来。
Xi−max(X) 这个操作是在指数上的,也就是说这里是 eXi÷emax(X)
g−fga÷g===a÷ee+e÷ff
所以在for循环X∈RN,i=n时,当新值需要更新max_v
值时,直接按新的max算,并把前面已经加好的值除去新老max的差值。
if n+1=maxn
i=0∑n+1eXi=i=0∑neXi÷e(Xn+1−maxn)+eXn+1−max(Xn+1,maxn)=i=0∑neXi×e(maxn−Xn+1)+1
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| vector<float> arr; float sum = 0; float max_v = arr[0]; for(size_t i = 0 ;i < arr.size(); i++){ if(arr[i] > max_v){ sum *= exp(max_v - arr[i]); sum += 1; max_v = arr[i]; } else{ sum += exp(arr[i] - max_v); } }
for(size_t i = 0 ;i < arr.size(); i++){ arr[i] = exp(arr[i] - max_v) / sum ; }
|
这样,可以一边求max一边算sum,而且元素直接互不干扰,只要把少的max_v
补上就行了,在GPU上分块求和也不影响,block内算自己的,最后block间把差的max_v
补齐就是了。
Flash Attention
切开算,基于上文online softmax 的方法尽可能的分块算,然后在最后对齐结果。
Attention 计算公式:
X是输入,N是sequence length,d是hidden size
QKV=XWQ∈RN×d=XWK∈RN×d=XWV∈RN×d
从这开始:Attention输入为Q,K,V三矩阵,输出O矩阵
SO=dSoftmax(QKT)=SV∈RN×N∈RN×d
算法过程

将Q,K,V都切成B=⌈4dM⌉ 大小,M是SRAM的size,比如8K。4d是qkvo的空间,这样切目的是把sram用满。
好了,QKVO都切成了T=⌈M4Nd⌉ 个块。(假设都刚刚好整除)
B是块大小,T是块个数
两层循环:外层for K,V,用j
,内层for Q用i
for j in Tfor i in TSij=QiKjT∈RB×B
用前面softmax中求分母的方法得到sumlij和最大值mij
ij是新的
Pijmijlij=eSij−mij=rowmax(Sij)=sum(Pij)∈RB×B∈RB∈RB
把mij ,lij更新到当前的mi′ ,li′
i保留着历史值,每次outerloop j把所有i过一遍留着下一轮用
mi′li′=max(mi,mij)=emi−mi′li+emij−mi′lij∈RB∈RB×B
计算最后的O=SV,O在z维度累加,因为sum和max在变,所以这里也要在线更新历史结果,补齐max导致的差异
Oi=diag(li′)−1(diag(li)emi−mi′Oi+emij−mi′PijVj)
最后把Oi li′ mi′写回global mem,继续下一个loop i