【大模型LLM學習】Flash-Attention的學習記錄
- 0. 前言
- 1. flash-attention原理簡述
- 2. 從softmax到online softmax
- 2.1 safe-softmax
- 2.2 3-pass safe softmax
- 2.3 Online softmax
- 2.4 Flash-attention
- 2.5 Flash-attention tiling
0. 前言
??Flash Attention可以節約模型訓練和推理時間,很多模型可以通過config參數來選擇attention是標準的attention實現還是flash-attention方式。在這里記錄一下flash attention的學習過程,發現了一位博主以及參考的資料特別好:
- zhihu一位做高性能計算的博主博文
- 華盛頓大學的課程note
1. flash-attention原理簡述
a t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V attention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V attention(Q,K,V)=softmax(dk??QKT?)V
??標準的attention操作的時間卡點不是在運算上,而是卡在數據讀寫上。SRAM的讀寫速度快,但是存儲空間有限,無法一次存下來所有的中間計算結果,一次attention計算存在SRAM<->HBM的多次讀寫操作。
??與標準的attention操作比較,flash-attention通過減少數據在HBM和SRAM間的讀寫操作,來節約時間(甚至backward時還進行了重新計算,重新計算的速度也比把數據從HBM讀取到SRAM要快)。
2. 從softmax到online softmax
??直接看flash-attention的論文比較難看明白,發現華盛頓大學的那份note寫得特別清晰,跟著它從softmax看到flash-attention會比較容易。
2.1 safe-softmax
??首先是safe的softmax計算方式。原始的softmax,對于N個數:
s o f t m a x ( { x 1 , . . . , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N softmax(\{x_1,...,x_N\})=\left\{\frac{e^{x_i}}{\sum_{j=1}^{N}e^{x_j}}\right\}_{i=1}^{N} softmax({x1?,...,xN?})={∑j=1N?exj?exi??}i=1N?
??對于FP16,最大能表示的數據為65536,當 x > = 11 x>=11 x>=11時, e x e^x ex就會超過FP16的最大表示范圍影響結果的正確性。為了避免這個問題,SafeSoftmax 通過減去輸入向量中的最大值來調整輸入,使得最大的指數項變為 e 0 = 1 e^0=1 e0=1從而防止了上溢的發生。同時,由于所有的指數項都除以同一個數,它們的比例關系不會改變,因此也不會影響最終的概率分布。
e x i ∑ j = 1 N e x j = e x i ? m ∑ j = 1 N e x j ? m , m = m a x { x j } j = 1 N \frac{e^{x_i}}{\sum_{j=1}{N}e^{x_j}}=\frac{e^{x_i-m}}{\sum_{j=1}{N}e^{x_j-m}}, \quad m=max\left\{x_j\right\}_{j=1}^{N} ∑j=1?Nexj?exi??=∑j=1?Nexj??mexi??m?,m=max{xj?}j=1N?
2.2 3-pass safe softmax
- 對于一個行向量 { x i } i = 1 N \{x_i\}_{i=1}^N {xi?}i=1N?,最直白的softmax計算方式是直接for循環
??這個算法計算softmax需要執行3次從1->N的循環,在attention中, { x i } \{x_i\} {xi?}是 Q K T QK^T QKT的結果,但是如果SRAM里面存不下這個大的矩陣,上面的計算過程,就需要從HBM里面加載3次 { x i } \{x_i\} {xi?},時間花在了數據讀寫上。
2.3 Online softmax
??如果能把上面(7)(8)(9)這3個式子的計算放一個for循環,就只需要一次load數據。但是 m N m_N mN?是全局最大值,計算 m N m_N mN?就已經需要一次遍歷了。
??Online softmax算法把(7)(8)進行了合并,把3次遍歷縮減為2個。它提出計算 d i ′ = ∑ j = 1 i e x j ? m i d_i^{\prime}=\sum_{j=1}^{i}e^{x_j-m_i} di′?=∑j=1i?exj??mi?來代替計算 d i d_i di?,當算到最后 i = N i=N i=N時會發現, d N = d N ′ d_N=d_N^{\prime} dN?=dN′?。具體的,迭代計算 d i ′ d_i^{\prime} di′?的方式為:
d i ′ = ∑ j = 1 i e x j ? m i = ( ∑ j = 1 i ? 1 e x j ? m i ) + e x i ? m i = ( ∑ j = 1 i ? 1 e x j ? m i ? 1 ) e m i ? 1 ? m i + e x i ? m i = d i ? 1 ′ e m i ? 1 ? m i + e x i ? m i \begin{aligned} d_i^{\prime} &= \sum_{j=1}^{i} e^{x_j - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_i} \right) + e^{x_i - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \right) e^{m_{i-1} - m_i} + e^{x_i - m_i} \\ &= d_{i-1}^{\prime} e^{m_{i-1} - m_i} + e^{x_i - m_i} \end{aligned} di′??=j=1∑i?exj??mi?=(j=1∑i?1?exj??mi?)+exi??mi?=(j=1∑i?1?exj??mi?1?)emi?1??mi?+exi??mi?=di?1′?emi?1??mi?+exi??mi??
??所以就可以用迭代的方式,在找最大值 m N m_N mN?的時候,同時來計算 d i ′ d_i^{\prime} di′?,把(7)和(8)一起計算,這樣只需要加載兩次 x i x_i xi?。
2.4 Flash-attention
??上面的online softmax仍然需要2個for循環,加載2次 x i x_i xi?來完成softmax的計算。完成softmax的計算,沒法更進一步地壓縮到1次遍歷。但是attention計算的最終目標是獲取輸出結果,也就是注意力分數與 V V V相乘的結果 O = A × V O=A \times V O=A×V,計算 O O O可以通過一次遍歷完成。
??可以使用類似online softmax把計算 d i d_i di?變成計算 d i ′ d_i^{\prime} di′?的方式,把 o i o_i oi?的計算也改成迭代式的,首先把 a i a_i ai?帶入 o i o_i oi?的表達式
o i = ∑ j = 1 i ( e x j ? m N d N ′ V [ j , : ] ) o_i=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{N}}}{d_N^{\prime}}V[j,:]\right) oi?=j=1∑i?(dN′?exj??mN??V[j,:])
??可以找到一個 o i ′ o_i^{\prime} oi′?,它不依賴于全局的 d N ′ d_N^{\prime} dN′?和 m N m_N mN?
o i ′ = ∑ j = 1 i ( e x j ? m i d i ′ V [ j , : ] ) o_i^{\prime}=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{i}}}{d_i^{\prime}}V[j,:]\right) oi′?=j=1∑i?(di′?exj??mi??V[j,:])
??對于 o i ′ o_i^{\prime} oi′?的計算可以使用迭代的方式,同樣的是有 o N = o N ′ o_N=o_N^{\prime} oN?=oN′?
o i ′ = ∑ j = 1 i e x j ? m i d i ′ V [ j , : ] = ( ∑ j = 1 i ? 1 e x j ? m i d i ′ V [ j , : ] ) + e x i ? m i d i ′ V [ i , : ] = ( ∑ j = 1 i ? 1 e x j ? m i ? 1 d i ? 1 ′ e x j ? m i e x j ? m i ? 1 d i ? 1 ′ d i ′ V [ j , : ] ) + e x i ? m i d i ′ V [ i , : ] = ( ∑ j = 1 i ? 1 e x j ? m i ? 1 d i ? 1 ′ V [ j , : ] ) d i ? 1 ′ d i ′ e m i ? 1 ? m i + e x i ? m i d i ′ V [ i , : ] = o i ? 1 ′ d i ? 1 ′ e m i ? 1 ? m i d i ′ + e x i ? m i d i ′ V [ i , : ] \begin{aligned} o_i' &= \sum_{j=1}^{i} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} \frac{e^{x_j - m_i}}{e^{x_j - m_{i-1}}} \frac{d_{i-1}'}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} V[j,:] \right) \frac{d_{i-1}'}{d_i'} e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= o_{i-1}' \frac{d_{i-1}' e^{m_{i-1} - m_i}}{d_i'} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \end{aligned} oi′??=j=1∑i?di′?exj??mi??V[j,:]=(j=1∑i?1?di′?exj??mi??V[j,:])+di′?exi??mi??V[i,:]=(j=1∑i?1?di?1′?exj??mi?1??exj??mi?1?exj??mi??di′?di?1′??V[j,:])+di′?exi??mi??V[i,:]=(j=1∑i?1?di?1′?exj??mi?1??V[j,:])di′?di?1′??emi?1??mi?+di′?exi??mi??V[i,:]=oi?1′?di′?di?1′?emi?1??mi??+di′?exi??mi??V[i,:]?
??這樣計算attention的輸出結果可以只進行一次遍歷就完成
2.5 Flash-attention tiling
??上面是每次計算一個元素 [ i ] [i] [i],實際上可以一次讀取一個大小為b的塊(tile)來計算
??此外,在flash-attention的paper里面,對 Q Q Q、 K K K、 V V V和 O O O分塊,其中 Q Q Q
和 O O O每塊大小為 m i n ( M / 4 d , d ) × d min(M/4d,d) \times d min(M/4d,d)×d, K / V K/V K/V的每塊大小為 M / 4 d × d M/4d \times d M/4d×d,加起來正好不會超過SRAM的大小M,完整的算法在paper中: