【大模型LLM學習】Flash-Attention的學習記錄

【大模型LLM學習】Flash-Attention的學習記錄

  • 0. 前言
  • 1. flash-attention原理簡述
  • 2. 從softmax到online softmax
    • 2.1 safe-softmax
    • 2.2 3-pass safe softmax
    • 2.3 Online softmax
    • 2.4 Flash-attention
    • 2.5 Flash-attention tiling

0. 前言

??Flash Attention可以節約模型訓練和推理時間,很多模型可以通過config參數來選擇attention是標準的attention實現還是flash-attention方式。在這里記錄一下flash attention的學習過程,發現了一位博主以及參考的資料特別好:

  • zhihu一位做高性能計算的博主博文
  • 華盛頓大學的課程note

1. flash-attention原理簡述

a t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V attention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V attention(Q,K,V)=softmax(dk? ?QKT?)V
??標準的attention操作的時間卡點不是在運算上,而是卡在數據讀寫上。SRAM的讀寫速度快,但是存儲空間有限,無法一次存下來所有的中間計算結果,一次attention計算存在SRAM<->HBM的多次讀寫操作。
在這里插入圖片描述
??與標準的attention操作比較,flash-attention通過減少數據在HBM和SRAM間的讀寫操作,來節約時間(甚至backward時還進行了重新計算,重新計算的速度也比把數據從HBM讀取到SRAM要快)。
https://huggingface.co/docs/text-generation-inference/conceptual/flash_attention

2. 從softmax到online softmax

??直接看flash-attention的論文比較難看明白,發現華盛頓大學的那份note寫得特別清晰,跟著它從softmax看到flash-attention會比較容易。

2.1 safe-softmax

??首先是safe的softmax計算方式。原始的softmax,對于N個數:
s o f t m a x ( { x 1 , . . . , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N softmax(\{x_1,...,x_N\})=\left\{\frac{e^{x_i}}{\sum_{j=1}^{N}e^{x_j}}\right\}_{i=1}^{N} softmax({x1?,...,xN?})={j=1N?exj?exi??}i=1N?
??對于FP16,最大能表示的數據為65536,當 x > = 11 x>=11 x>=11時, e x e^x ex就會超過FP16的最大表示范圍影響結果的正確性。為了避免這個問題,SafeSoftmax 通過減去輸入向量中的最大值來調整輸入,使得最大的指數項變為 e 0 = 1 e^0=1 e0=1從而防止了上溢的發生。同時,由于所有的指數項都除以同一個數,它們的比例關系不會改變,因此也不會影響最終的概率分布。
e x i ∑ j = 1 N e x j = e x i ? m ∑ j = 1 N e x j ? m , m = m a x { x j } j = 1 N \frac{e^{x_i}}{\sum_{j=1}{N}e^{x_j}}=\frac{e^{x_i-m}}{\sum_{j=1}{N}e^{x_j-m}}, \quad m=max\left\{x_j\right\}_{j=1}^{N} j=1?Nexj?exi??=j=1?Nexj??mexi??m?,m=max{xj?}j=1N?

2.2 3-pass safe softmax

  • 對于一個行向量 { x i } i = 1 N \{x_i\}_{i=1}^N {xi?}i=1N?,最直白的softmax計算方式是直接for循環

在這里插入圖片描述
??這個算法計算softmax需要執行3次從1->N的循環,在attention中, { x i } \{x_i\} {xi?} Q K T QK^T QKT的結果,但是如果SRAM里面存不下這個大的矩陣,上面的計算過程,就需要從HBM里面加載3次 { x i } \{x_i\} {xi?},時間花在了數據讀寫上。

2.3 Online softmax

??如果能把上面(7)(8)(9)這3個式子的計算放一個for循環,就只需要一次load數據。但是 m N m_N mN?是全局最大值,計算 m N m_N mN?就已經需要一次遍歷了。
??Online softmax算法把(7)(8)進行了合并,把3次遍歷縮減為2個。它提出計算 d i ′ = ∑ j = 1 i e x j ? m i d_i^{\prime}=\sum_{j=1}^{i}e^{x_j-m_i} di?=j=1i?exj??mi?來代替計算 d i d_i di?,當算到最后 i = N i=N i=N時會發現, d N = d N ′ d_N=d_N^{\prime} dN?=dN?。具體的,迭代計算 d i ′ d_i^{\prime} di?的方式為:
d i ′ = ∑ j = 1 i e x j ? m i = ( ∑ j = 1 i ? 1 e x j ? m i ) + e x i ? m i = ( ∑ j = 1 i ? 1 e x j ? m i ? 1 ) e m i ? 1 ? m i + e x i ? m i = d i ? 1 ′ e m i ? 1 ? m i + e x i ? m i \begin{aligned} d_i^{\prime} &= \sum_{j=1}^{i} e^{x_j - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_i} \right) + e^{x_i - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \right) e^{m_{i-1} - m_i} + e^{x_i - m_i} \\ &= d_{i-1}^{\prime} e^{m_{i-1} - m_i} + e^{x_i - m_i} \end{aligned} di??=j=1i?exj??mi?=(j=1i?1?exj??mi?)+exi??mi?=(j=1i?1?exj??mi?1?)emi?1??mi?+exi??mi?=di?1?emi?1??mi?+exi??mi??

??所以就可以用迭代的方式,在找最大值 m N m_N mN?的時候,同時來計算 d i ′ d_i^{\prime} di?,把(7)和(8)一起計算,這樣只需要加載兩次 x i x_i xi?

在這里插入圖片描述

2.4 Flash-attention

??上面的online softmax仍然需要2個for循環,加載2次 x i x_i xi?來完成softmax的計算。完成softmax的計算,沒法更進一步地壓縮到1次遍歷。但是attention計算的最終目標是獲取輸出結果,也就是注意力分數與 V V V相乘的結果 O = A × V O=A \times V O=A×V,計算 O O O可以通過一次遍歷完成。
在這里插入圖片描述
??可以使用類似online softmax把計算 d i d_i di?變成計算 d i ′ d_i^{\prime} di?的方式,把 o i o_i oi?的計算也改成迭代式的,首先把 a i a_i ai?帶入 o i o_i oi?的表達式
o i = ∑ j = 1 i ( e x j ? m N d N ′ V [ j , : ] ) o_i=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{N}}}{d_N^{\prime}}V[j,:]\right) oi?=j=1i?(dN?exj??mN??V[j,:])

??可以找到一個 o i ′ o_i^{\prime} oi?,它不依賴于全局的 d N ′ d_N^{\prime} dN? m N m_N mN?
o i ′ = ∑ j = 1 i ( e x j ? m i d i ′ V [ j , : ] ) o_i^{\prime}=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{i}}}{d_i^{\prime}}V[j,:]\right) oi?=j=1i?(di?exj??mi??V[j,:])

??對于 o i ′ o_i^{\prime} oi?的計算可以使用迭代的方式,同樣的是有 o N = o N ′ o_N=o_N^{\prime} oN?=oN?
o i ′ = ∑ j = 1 i e x j ? m i d i ′ V [ j , : ] = ( ∑ j = 1 i ? 1 e x j ? m i d i ′ V [ j , : ] ) + e x i ? m i d i ′ V [ i , : ] = ( ∑ j = 1 i ? 1 e x j ? m i ? 1 d i ? 1 ′ e x j ? m i e x j ? m i ? 1 d i ? 1 ′ d i ′ V [ j , : ] ) + e x i ? m i d i ′ V [ i , : ] = ( ∑ j = 1 i ? 1 e x j ? m i ? 1 d i ? 1 ′ V [ j , : ] ) d i ? 1 ′ d i ′ e m i ? 1 ? m i + e x i ? m i d i ′ V [ i , : ] = o i ? 1 ′ d i ? 1 ′ e m i ? 1 ? m i d i ′ + e x i ? m i d i ′ V [ i , : ] \begin{aligned} o_i' &= \sum_{j=1}^{i} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} \frac{e^{x_j - m_i}}{e^{x_j - m_{i-1}}} \frac{d_{i-1}'}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} V[j,:] \right) \frac{d_{i-1}'}{d_i'} e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= o_{i-1}' \frac{d_{i-1}' e^{m_{i-1} - m_i}}{d_i'} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \end{aligned} oi??=j=1i?di?exj??mi??V[j,:]=(j=1i?1?di?exj??mi??V[j,:])+di?exi??mi??V[i,:]=(j=1i?1?di?1?exj??mi?1??exj??mi?1?exj??mi??di?di?1??V[j,:])+di?exi??mi??V[i,:]=(j=1i?1?di?1?exj??mi?1??V[j,:])di?di?1??emi?1??mi?+di?exi??mi??V[i,:]=oi?1?di?di?1?emi?1??mi??+di?exi??mi??V[i,:]?

??這樣計算attention的輸出結果可以只進行一次遍歷就完成
在這里插入圖片描述

2.5 Flash-attention tiling

??上面是每次計算一個元素 [ i ] [i] [i],實際上可以一次讀取一個大小為b的塊(tile)來計算

在這里插入圖片描述在這里插入圖片描述

??此外,在flash-attention的paper里面,對 Q Q Q K K K V V V O O O分塊,其中 Q Q Q
O O O每塊大小為 m i n ( M / 4 d , d ) × d min(M/4d,d) \times d min(M/4d,d)×d K / V K/V K/V的每塊大小為 M / 4 d × d M/4d \times d M/4d×d,加起來正好不會超過SRAM的大小M,完整的算法在paper中:
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/82710.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/82710.shtml
英文地址,請注明出處:http://en.pswp.cn/web/82710.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

python打卡day46@浙大疏錦行

知識點回顧&#xff1a; 不同CNN層的特征圖&#xff1a;不同通道的特征圖什么是注意力&#xff1a;注意力家族&#xff0c;類似于動物園&#xff0c;都是不同的模塊&#xff0c;好不好試了才知道。通道注意力&#xff1a;模型的定義和插入的位置通道注意力后的特征圖和熱力圖 內…

JavaSec-SPEL - 表達式注入

簡介 SPEL(Spring Expression Language)&#xff1a;SPEL是Spring表達式語言&#xff0c;允許在運行時動態查詢和操作對象屬性、調用方法等&#xff0c;類似于Struts2中的OGNL表達式。當參數未經過濾時&#xff0c;攻擊者可以注入惡意的SPEL表達式&#xff0c;從而執行任意代碼…

SpringCloud——OpenFeign

概述&#xff1a; OpenFeign是基于Spring的聲明式調用的HTTP客戶端&#xff0c;大大簡化了編寫Web服務客戶端的過程&#xff0c;用于快速構建http請求調用其他服務模塊。同時也是spring cloud默認選擇的服務通信工具。 使用方法&#xff1a; RestTemplate手動構建: // 帶查詢…

【深入學習Linux】System V共享內存

目錄 前言 一、共享內存是什么&#xff1f; 共享內存實現原理 共享內存細節理解 二、接口認識 1.shmget函數——申請共享內存 2.ftok函數——生成key值 再次理解ftok和shmget 1&#xff09;key與shmid的區別與聯系 2&#xff09;再理解key 3&#xff09;通過指令查看/釋放系統中…

探索 Java 垃圾收集:對象存活判定、回收流程與內存策略

個人主頁-愛因斯晨 文章專欄-JAVA學習筆記 熱門文章-賽博算命 一、引言 在 Java 技術體系里&#xff0c;垃圾收集器&#xff08;Garbage Collection&#xff0c;GC&#xff09;與內存分配策略是自動內存管理的核心支撐。深入探究其原理與機制&#xff0c;對優化程序內存性能…

hbase資源和數據權限控制

hbase適合大數據量下點查 https://zhuanlan.zhihu.com/p/471133280 HBase支持對User、NameSpace和Table進行請求數和流量配額限制&#xff0c;限制頻率可以按sec、min、hour、day 對于請求大小限制示例&#xff08;5K/sec,10M/min等&#xff09;&#xff0c;請求大小限制單位如…

大數據-275 Spark MLib - 基礎介紹 機器學習算法 集成學習 隨機森林 Bagging Boosting

點一下關注吧&#xff01;&#xff01;&#xff01;非常感謝&#xff01;&#xff01;持續更新&#xff01;&#xff01;&#xff01; 大模型篇章已經開始&#xff01; 目前已經更新到了第 22 篇&#xff1a;大語言模型 22 - MCP 自動操作 FigmaCursor 自動設計原型 Java篇開…

Delphi 實現遠程連接 Access 數據庫的指南

方法一&#xff1a;通過局域網共享 Access 文件&#xff08;簡單但有限&#xff09; 步驟 1&#xff1a;共享 Access 數據庫 將 .mdb 或 .accdb 文件放在局域網內某臺電腦的共享文件夾中。 右鍵文件夾 → 屬性 → 共享 → 啟用共享并設置權限&#xff08;需允許網絡用戶讀寫&a…

VR視頻制作有哪些流程?

VR視頻制作流程知識 VR視頻制作&#xff0c;作為融合了創意與技術的復雜制作過程&#xff0c;涵蓋從初步策劃到最終呈現的多個環節。在這個過程中&#xff0c;我們可以結合眾趣科技的產品&#xff0c;解析每一環節的實現與優化&#xff0c;揭示背后的奧秘。 VR視頻制作有哪些…

文件上傳/下載接口開發

接口特性 文件傳輸接口與傳統接口的核心差異體現在數據傳輸格式&#xff1a; 上傳接口采用 multipart/form-data 格式支持二進制文件傳輸下載接口接收二進制流并實現本地文件存儲 文件上傳接口開發 接口規范 請求地址&#xff1a;/createbyfile 請求方式&#xff1a;POST…

深入學習RabbitMQ隊列的知識

目錄 1、AMQP協議 1.1、介紹 1.2、AMQP的特點 1.3、工作流程 1.4、消息模型 1.5、消息結構 1.6、AMQP 的交換器類型 2、RabbitMQ結構介紹 2.1、核心組件 2.2、最大特點 2.3、工作原理 3、消息可靠性保障 3.1、生產端可靠性 1、生產者確認機制 2、持久化消息 3.…

【計算機網絡】NAT、代理服務器、內網穿透、內網打洞、局域網中交換機

&#x1f525;個人主頁&#x1f525;&#xff1a;孤寂大仙V &#x1f308;收錄專欄&#x1f308;&#xff1a;計算機網絡 &#x1f339;往期回顧&#x1f339;&#xff1a;【計算機網絡】數據鏈路層——ARP協議 &#x1f516;流水不爭&#xff0c;爭的是滔滔不息 一、網絡地址轉…

[論文閱讀] 人工智能 | 大語言模型計劃生成的新范式:基于過程挖掘的技能學習

#論文閱讀# 大語言模型計劃生成的新范式&#xff1a;基于過程挖掘的技能學習 論文信息 Skill Learning Using Process Mining for Large Language Model Plan Generation Andrei Cosmin Redis, Mohammadreza Fani Sani, Bahram Zarrin, Andrea Burattin Cite as: arXiv:2410.…

C文件操作2

五、文件的隨機讀寫 這些函數都需要包含頭文件 #include<stdio.h> 5.1 fseek 根據文件指針的位置和偏移量來定位文件指針&#xff08;文件內容的光標&#xff09; &#xff08;重新定位流位置指示器&#xff09; int fseek ( FILE * stream, long int offset, int or…

react私有樣式處理

react私有樣式處理 Nav.jsx Menu.jsx vue中通過scoped來實現樣式私有化。加上scoped&#xff0c;就屬于當前組件的私有樣式。 給視圖中的元素都加了一個屬性data-v-xxx&#xff0c;然后給這些樣式都加上屬性選擇器。&#xff08;deep就是不加屬性也不加屬性選擇器&#xff09; …

【信創-k8s】海光/兆芯+銀河麒麟V10離線部署k8s1.31.8+kubesphere4.1.3

? KubeSphere V4已經開源半年多&#xff0c;而且v4.1.3也已經出來了&#xff0c;修復了眾多bug。介于V4優秀的LuBan架構&#xff0c;核心組件非常少&#xff0c;資源占用也顯著降低&#xff0c;同時帶來眾多功能和便利性。我們決定與時俱進&#xff0c;使用1.30版本的Kubernet…

單片機內部結構基礎知識 FLASH相關解讀

一、總線簡單說明 地址總線、控制總線、數據總線 什么是8位8051框架結構的微控制器&#xff1f; 數據總線寬度為8位&#xff0c;即CPU一次處理或傳輸的數據量為8位&#xff08;1字節&#xff09; 同時還有一個16位的地址總線&#xff0c;這個地方也剛好對應了為什么能看到內存…

HTTPS加密的介紹

HTTPS&#xff08;HyperText Transfer Protocol Secure&#xff0c;超文本傳輸安全協議&#xff09;是HTTP協議的安全版本。它在HTTP的基礎上加入了SSL/TLS協議&#xff0c;用于對數據進行加密&#xff0c;并確保數據傳輸過程中的機密性、完整性和身份驗證。 在HTTPS出現之前&a…

【freertos-kernel】stream_buffer

文章目錄 補充任務通知發送處理ulTaskGenericNotifyTakexTaskGenericNotifyWait 清除xTaskGenericNotifyStateClearulTaskGenericNotifyValueClear 結構體StreamBufferHandle_tStreamBufferCallbackFunction_t 創建xStreamBufferGenericCreatestream buffer的類型 刪除vStreamB…

在word中點擊zotero Add/Edit Citation沒有反應的解決辦法

重新安裝了word插件 1.關掉word 2.進入Zotero左上角編輯-引用 3.往下滑找到Microsoft Word&#xff0c;點重新安裝加載項