SDPA(Scaled Dot-Product Attention)詳解

SDPA(Scaled Dot-Product Attention)詳解

SDPA(Scaled Dot-Product Attention,縮放點積注意力)是 Transformer 模型的核心計算單元,最早由 Vaswani 等人在 2017 年的論文《Attention Is All You Need》提出。它通過計算查詢(Query)、鍵(Key)和值(Value)之間的相似度,生成上下文感知的表示。


1. SDPA 的數學定義

給定:

  • 查詢矩陣(Query) Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} QRn×dk?
  • 鍵矩陣(Key) K ∈ R m × d k K \in \mathbb{R}^{m \times d_k} KRm×dk?
  • 值矩陣(Value) V ∈ R m × d v V \in \mathbb{R}^{m \times d_v} VRm×dv?

SDPA 的計算公式為:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dk? ?QKT?)V

其中:

  • Q K T QK^T QKT 計算查詢和鍵的點積(相似度)。
  • d k \sqrt{d_k} dk? ? 用于縮放點積,防止梯度消失或爆炸(尤其是 d k d_k dk? 較大時)。
  • softmax 將注意力權重歸一化為概率分布。
  • 最終加權求和 V V V 得到輸出。

2. SDPA 的計算步驟

  1. 計算相似度(Dot-Product)
  • 計算 Q Q Q K K K 的點積:
    S = Q K T S = QK^T S=QKT
  • 相似度矩陣 S ∈ R n × m S \in \mathbb{R}^{n \times m} SRn×m 表示每個查詢對所有鍵的匹配程度。
  1. 縮放(Scaling)

    • 除以 d k \sqrt{d_k} dk? ?(鍵向量的維度),防止點積值過大導致 softmax 梯度消失:
      S scaled = S d k S_{\text{scaled}} = \frac{S}{\sqrt{d_k}} Sscaled?=dk? ?S?
  2. Softmax 歸一化

    • 對每行(每個查詢)做 softmax,得到注意力權重 A A A
      A = softmax ( S scaled ) A = \text{softmax}(S_{\text{scaled}}) A=softmax(Sscaled?)
    • 保證 ∑ j A i , j = 1 \sum_j A_{i,j} = 1 j?Ai,j?=1,權重總和為 1。
  3. 加權求和(Value 聚合)

    • 用注意力權重 A A A V V V 加權求和,得到最終輸出:
      Output = A ? V \text{Output} = A \cdot V Output=A?V
    • 輸出維度: R n × d v \mathbb{R}^{n \times d_v} Rn×dv?

3. SDPA 的作用與優勢

? 核心作用

  • 讓模型動態關注輸入的不同部分(類似人類注意力機制)。
  • 適用于序列數據(如文本、語音、視頻),捕捉長距離依賴。

? 優勢

  1. 并行計算友好
  • 矩陣乘法(GEMM)可高效并行加速(GPU/TPU 優化)。
  1. 可解釋性
    • 注意力權重可視化(如 BertViz)可分析模型關注哪些 token。
  2. 靈活擴展
    • 可結合 多頭注意力(Multi-Head Attention) 增強表達能力。

4. SDPA 的變體與優化

變體/優化核心改進應用場景
多頭注意力(MHA)并行多個 SDPA,增強特征多樣性Transformer (BERT, GPT)
FlashAttention優化內存訪問,減少 HBM 讀寫長序列推理(如 8K+ tokens)
Sparse Attention只計算局部或稀疏的注意力降低計算復雜度(如 Longformer)
Linear Attention用線性近似替代 softmax低資源設備(如 RetNet)

5. 代碼實現(PyTorch 示例)

import torch
import torch.nn.functional as Fdef scaled_dot_product_attention(Q, K, V, mask=None):d_k = Q.size(-1)scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)attn_weights = F.softmax(scores, dim=-1)output = torch.matmul(attn_weights, V)return output# 示例輸入
Q = torch.randn(2, 5, 64)  # (batch_size, seq_len, d_k)
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 128)
output = scaled_dot_product_attention(Q, K, V)
print(output.shape)  # torch.Size([2, 5, 128])

6. 總結

  • SDPA 是 Transformer 的基石,通過 Query-Key-Value 機制 + Softmax 歸一化 實現動態注意力。
  • 關鍵優化點:縮放(防止梯度問題)、并行計算、內存效率(如 FlashAttention)
  • 現代優化(如 SageAttention2)進一步結合 量化、稀疏化、離群值處理 提升效率。

SDPA 及其變體已成為 NLP、CV、多模態領域的核心組件,理解其原理對模型優化至關重要。

SDPA計算過程舉例

我們通過一個具體的數值例子,逐步演示 SDPA 的計算過程。假設輸入如下(簡化版,便于手動計算):

輸入數據(假設 d_k = 2, d_v = 3
  • Query (Q):2 個查詢(n=2),每個查詢維度 d_k=2
    Q = [ 1 2 3 4 ] Q = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ \end{bmatrix} Q=[13?24?]
  • Key (K):3 個鍵(m=3),每個鍵維度 d_k=2
    K = [ 5 6 7 8 9 10 ] K = \begin{bmatrix} 5 & 6 \\ 7 & 8 \\ 9 & 10 \\ \end{bmatrix} K= ?579?6810? ?
  • Value (V):3 個值(m=3),每個值維度 d_v=3
    V = [ 1 0 1 0 1 0 1 1 0 ] V = \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 0 \\ 1 & 1 & 0 \\ \end{bmatrix} V= ?101?011?100? ?

Step 1: 計算 Query 和 Key 的點積(Dot-Product)

計算 S = Q K T S = QK^T S=QKT

Q K T = [ 1 ? 5 + 2 ? 6 1 ? 7 + 2 ? 8 1 ? 9 + 2 ? 10 3 ? 5 + 4 ? 6 3 ? 7 + 4 ? 8 3 ? 9 + 4 ? 10 ] = [ 5 + 12 7 + 16 9 + 20 15 + 24 21 + 32 27 + 40 ] = [ 17 23 29 39 53 67 ] QK^T = \begin{bmatrix} 1 \cdot 5 + 2 \cdot 6 & 1 \cdot 7 + 2 \cdot 8 & 1 \cdot 9 + 2 \cdot 10 \\ 3 \cdot 5 + 4 \cdot 6 & 3 \cdot 7 + 4 \cdot 8 & 3 \cdot 9 + 4 \cdot 10 \\ \end{bmatrix} = \begin{bmatrix} 5+12 & 7+16 & 9+20 \\ 15+24 & 21+32 & 27+40 \\ \end{bmatrix} = \begin{bmatrix} 17 & 23 & 29 \\ 39 & 53 & 67 \\ \end{bmatrix} QKT=[1?5+2?63?5+4?6?1?7+2?83?7+4?8?1?9+2?103?9+4?10?]=[5+1215+24?7+1621+32?9+2027+40?]=[1739?2353?2967?]


Step 2: 縮放(Scaling)

除以 d k = 2 ≈ 1.414 \sqrt{d_k} = \sqrt{2} \approx 1.414 dk? ?=2 ?1.414

S scaled = S 2 = [ 17 / 1.414 23 / 1.414 29 / 1.414 39 / 1.414 53 / 1.414 67 / 1.414 ] ≈ [ 12.02 16.26 20.51 27.58 37.48 47.38 ] S_{\text{scaled}} = \frac{S}{\sqrt{2}} = \begin{bmatrix} 17/1.414 & 23/1.414 & 29/1.414 \\ 39/1.414 & 53/1.414 & 67/1.414 \\ \end{bmatrix} \approx \begin{bmatrix} 12.02 & 16.26 & 20.51 \\ 27.58 & 37.48 & 47.38 \\ \end{bmatrix} Sscaled?=2 ?S?=[17/1.41439/1.414?23/1.41453/1.414?29/1.41467/1.414?][12.0227.58?16.2637.48?20.5147.38?]


Step 3: Softmax 歸一化(計算注意力權重)

對每一行(每個 Query)做 softmax:

$\text{softmax}([12.02, 16.26, 20.51]) \approx [2.06 \times 10^{-4}, 0.016, 0.984] $
$\text{softmax}([27.58, 37.48, 47.38]) \approx [1.67 \times 10^{-9}, 0.0001, 0.9999] $

因此,注意力權重矩陣 A A A 為:

A ≈ [ 2.06 × 10 ? 4 0.016 0.984 1.67 × 10 ? 9 0.0001 0.9999 ] A \approx \begin{bmatrix} 2.06 \times 10^{-4} & 0.016 & 0.984 \\ 1.67 \times 10^{-9} & 0.0001 & 0.9999 \\ \end{bmatrix} A[2.06×10?41.67×10?9?0.0160.0001?0.9840.9999?]

解釋

  • 第 1 個 Query 主要關注第 3 個 Key(權重 0.984)。
  • 第 2 個 Query 幾乎只關注第 3 個 Key(權重 0.9999)。

Step 4: 加權求和(聚合 Value)

計算 Output = A ? V \text{Output} = A \cdot V Output=A?V

Output = [ 2.06 × 10 ? 4 ? 1 + 0.016 ? 0 + 0.984 ? 1 2.06 × 10 ? 4 ? 0 + 0.016 ? 1 + 0.984 ? 1 2.06 × 10 ? 4 ? 1 + 0.016 ? 0 + 0.984 ? 0 ] T ≈ [ 0.984 1.000 0.0002 ] T \text{Output} = \begin{bmatrix} 2.06 \times 10^{-4} \cdot 1 + 0.016 \cdot 0 + 0.984 \cdot 1 \\ 2.06 \times 10^{-4} \cdot 0 + 0.016 \cdot 1 + 0.984 \cdot 1 \\ 2.06 \times 10^{-4} \cdot 1 + 0.016 \cdot 0 + 0.984 \cdot 0 \\ \end{bmatrix}^T \approx \begin{bmatrix} 0.984 \\ 1.000 \\ 0.0002 \\ \end{bmatrix}^T Output= ?2.06×10?4?1+0.016?0+0.984?12.06×10?4?0+0.016?1+0.984?12.06×10?4?1+0.016?0+0.984?0? ?T ?0.9841.0000.0002? ?T

Output = [ 0.984 1.000 0.0002 0.9999 0.9999 0.0001 ] \text{Output} = \begin{bmatrix} 0.984 & 1.000 & 0.0002 \\ 0.9999 & 0.9999 & 0.0001 \\ \end{bmatrix} Output=[0.9840.9999?1.0000.9999?0.00020.0001?]

解釋

  • 第 1 行:主要聚合了第 3 個 Value [1, 1, 0],但受前兩個 Value 微弱影響。
  • 第 2 行:幾乎完全由第 3 個 Value 決定。

最終輸出

Output ≈ [ 0.984 1.000 0.0002 0.9999 0.9999 0.0001 ] \text{Output} \approx \begin{bmatrix} 0.984 & 1.000 & 0.0002 \\ 0.9999 & 0.9999 & 0.0001 \\ \end{bmatrix} Output[0.9840.9999?1.0000.9999?0.00020.0001?]


總結

  1. 點積:計算 Query 和 Key 的相似度。
  2. 縮放:防止梯度爆炸/消失。
  3. Softmax:歸一化為概率分布。
  4. 加權求和:聚合 Value 得到最終表示。

這個例子展示了 SDPA 如何動態分配注意力權重,并生成上下文感知的輸出。實際應用中(如 Transformer),還會結合 多頭注意力(Multi-Head Attention) 增強表達能力。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/84762.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/84762.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/84762.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

java通過hutool工具生成二維碼實現掃碼跳轉功能

實現&#xff1a; 首先引入zxing和hutool工具依賴 <dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.2</version></dependency><dependency><groupId>com.google.zxi…

數據庫數據導出到Excel表格

1.后端代碼 第一步&#xff1a;UserMapper定義根據ID列表批量查詢用戶方法 // 批量查詢用戶信息List<User> selectUserByIds(List<Integer> ids); 第二步&#xff1a;UserMapper.xml寫動態SQL&#xff0c;實現批量查詢用戶 <!--根據Ids批量查詢用戶-->&l…

Altera系列FPGA基于ADV7180解碼PAL視頻,純verilog去隔行,提供2套Quartus工程源碼和技術支持

目錄 1、前言工程概述免責聲明 2、相關方案推薦我已有的所有工程源碼總目錄----方便你快速找到自己喜歡的項目Altera系列FPGA相關方案推薦我這里已有的PAL視頻解碼方案 3、設計思路框架工程設計原理框圖輸入PAL相機ADV7180芯片解讀BT656視頻解碼模塊圖像緩存架構輸出視頻格式轉…

【教程】Windows安全中心掃描設置排除文件

轉載請注明出處&#xff1a;小鋒學長生活大爆炸[xfxuezhagn.cn] 如果本文幫助到了你&#xff0c;歡迎[點贊、收藏、關注]哦~ 目錄 背景說明 解決方法 背景說明 即使已經把實時防護等設置全都關了&#xff0c;但Windows還是會不定時給你掃描&#xff0c;然后把風險軟件給刪了…

OPenCV CUDA模塊立體匹配------對立體匹配生成的視差圖進行雙邊濾波處理類cv::cuda::DisparityBilateralFilter

操作系統&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 編程語言&#xff1a;C11 算法描述 cv::cuda::DisparityBilateralFilter 是 OpenCV CUDA 模塊中的一個類&#xff0c;用于對立體匹配生成的視差圖進行雙邊濾波處理。這種濾波方法可…

自然語言處理期末復習

自然語言處理期末復習 一單元 自然語言處理基礎 兩個核心任務&#xff1a; 自然語言理解&#xff08;NLU, Natural Language Understanding&#xff09; 讓計算機“讀懂”人類語言&#xff0c;理解文本的語義、結構和意圖。 典型子任務包括&#xff1a;分詞、詞性標注、句法分…

黃仁勛在2025年巴黎VivaTech大會上的GTC演講:AI工廠驅動的工業革命(上)

引言 2025年6月12日,在巴黎VivaTech大會上,英偉達創始人兼CEO黃仁勛發表了題為"AI工廠驅動的工業革命"的GTC主題演講。這場持續約1小時35分鐘的演講不僅詳細闡述了英偉達在AI基礎設施、智能體技術、量子計算及機器人領域的最新突破,更系統性地勾勒出了人工智能如…

DMC-E 系列總線控制卡----雷賽板卡介紹(六)

應用軟件開發方法 DMC-E 系列總線運動控制卡的應用軟件可以在 Visual Basic 、 Visual C++ 、 C# 等高級語言 環境下開發。應用軟件開發之前,需保證 DMC-E 系列總線運動控制卡連接好從站,通過控制 卡 Motion 的 EtherCAT 總線配置界面掃描從站、設置總線通信周期…

題目類型——左右逢源

1、針對的題目&#xff1a;&#xff08;不一定正確或完整&#xff09; 數據結構為數組之類的線性結構&#xff08;也許可以拓展&#xff09;&#xff0c;于是數組中每個元素和其他元素的相對關系為左右或前后需要對數組中每個元素求解或者說最終解要根據每個元素的解得出每個元…

RAG檢索前處理

1. 查詢構建&#xff08;包括Text2SQL&#xff09; 查詢構建的相關技術棧&#xff1a; Text-to-SQLText-to-Cypher 從查詢中提取元數據&#xff08;Self-query Retriever&#xff09; 1.1 Text-to-SQL&#xff08;關系數據庫&#xff09; 1.1.1 大語言模型方法Text-to-SQL樣…

OmoFun動漫官網,動漫共和國最新入口|網頁版

OmoFun 動漫&#xff0c;又叫動漫共和國&#xff0c;是一個專注于提供豐富動漫資源的在線平臺&#xff0c;深受廣大動漫愛好者的喜愛。它匯集了海量的動漫資源&#xff0c;涵蓋日本動漫、國產動漫、歐美動漫等多種類型&#xff0c;無論是最新上映的熱門番劇還是經典老番&#x…

ue5的blender4.1groom毛發插件v012安裝和使用方法(排除了沖突錯誤)

關鍵出錯不出錯是看這個文件pyalembic-1.8.8-cp311-cp311-win_amd64.whl&#xff0c;解決和Alembic SQL工具&#xff09;的加載沖突&#xff01; 其他blender版本根據其內部的python版本選擇對應的文件解壓安裝。 1、安裝插件&#xff01;把GroomExporter_v012_Blender4.1.1(原…

windows安裝jekyll

windows安裝jekyll 安裝ruby 首先需要下載ruby RubyInstaller for Windows - RubyInstaller國內鏡像站 我的操作系統是win10所以我安裝的最新版&#xff0c;你們安裝的時候&#xff0c;也可以安裝最新版&#xff0c;我這里就不附加圖片了 如果你的ruby安裝完成之后&#x…

DBever工具自適應mysql不同版本的連接

DBever工具的連接便捷性 最近在使用DBever工具連接不同版本的mysql數據庫&#xff0c;發現這個工具確實比mysql-log工具要兼容性好很多&#xff0c;直接就可以連接不同版本的數據庫&#xff0c;比如常見的mysql數據庫版本&#xff1a;8.0和5.7&#xff0c;而且鏈接成功后&…

K8S認證|CKS題庫+答案| 10. Trivy 掃描鏡像安全漏洞

目錄 10. Trivy 掃描鏡像安全漏洞 免費獲取并激活 CKA_v1.31_模擬系統 題目 開始操作&#xff1a; 1&#xff09;、切換集群 2&#xff09;、切換到master并提權 3&#xff09;、查看Pod和鏡像對應關系 4&#xff09;、查看并去重鏡像名稱 5&#xff09;、掃描所有鏡…

Rust高級抽象

Rust 的高級抽象能力是其核心優勢之一&#xff0c;允許開發者通過特征&#xff08;Traits&#xff09;、泛型&#xff08;Generics&#xff09;、閉包&#xff08;Closures&#xff09;、迭代器&#xff08;Iterators&#xff09;等機制實現高度靈活和可復用的代碼。今天我們來…

Vue里面的映射方法

111.getters配置項 112.mapstate和mapgetter 113.&#xfeff;mapActions與&#xfeff;mapMutations 114.多組件共享數據 115.vuex模塊化&#xff0c;namespaces1 116.name&#xfeff;s&#xfeff;pace2

Node.js特訓專欄-基礎篇:2. JavaScript核心知識在Node.js中的應用

我將從變量、函數、異步編程等方面入手&#xff0c;結合Node.js實際應用場景&#xff0c;為你詳細闡述JavaScript核心知識在其中的運用&#xff1a; JavaScript核心知識在Node.js中的應用 在當今的軟件開發領域&#xff0c;Node.js憑借其高效的性能和強大的功能&#xff0c;成…

負載均衡LB》》LVS

LO 接口 LVS簡介 LVS&#xff08;Linux Virtual Server&#xff09;即Linux虛擬服務器&#xff0c;是由章文嵩博士主導的開源負載均衡項目&#xff0c;通過LVS提供的負載均衡技術和Linux操作系統實現一個高性能、高可用的服務器集群&#xff0c;它具有良好可靠性、可擴展性和可…

Modbus TCP轉DeviceNet網關配置溫控儀配置案例

某工廠生產線需將Modbus TCP協議的智能儀表接入DeviceNet網絡&#xff08;主站為PLC&#xff0c;如Rockwell ControlLogix&#xff09;&#xff0c;實現集中監控。需通過開疆智能Modbus TCP轉DeviceNet網關KJ-DVCZ-MTCPS完成協議轉換。Modbus TCP設備&#xff1a;溫控器&#x…