SDPA(Scaled Dot-Product Attention)詳解
SDPA(Scaled Dot-Product Attention,縮放點積注意力)是 Transformer 模型的核心計算單元,最早由 Vaswani 等人在 2017 年的論文《Attention Is All You Need》提出。它通過計算查詢(Query)、鍵(Key)和值(Value)之間的相似度,生成上下文感知的表示。
1. SDPA 的數學定義
給定:
- 查詢矩陣(Query): Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} Q∈Rn×dk?
- 鍵矩陣(Key): K ∈ R m × d k K \in \mathbb{R}^{m \times d_k} K∈Rm×dk?
- 值矩陣(Value): V ∈ R m × d v V \in \mathbb{R}^{m \times d_v} V∈Rm×dv?
SDPA 的計算公式為:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dk??QKT?)V
其中:
- Q K T QK^T QKT 計算查詢和鍵的點積(相似度)。
- d k \sqrt{d_k} dk?? 用于縮放點積,防止梯度消失或爆炸(尤其是 d k d_k dk? 較大時)。
- softmax 將注意力權重歸一化為概率分布。
- 最終加權求和 V V V 得到輸出。
2. SDPA 的計算步驟
- 計算相似度(Dot-Product)
- 計算 Q Q Q 和 K K K 的點積:
S = Q K T S = QK^T S=QKT - 相似度矩陣 S ∈ R n × m S \in \mathbb{R}^{n \times m} S∈Rn×m 表示每個查詢對所有鍵的匹配程度。
-
縮放(Scaling)
- 除以 d k \sqrt{d_k} dk??(鍵向量的維度),防止點積值過大導致 softmax 梯度消失:
S scaled = S d k S_{\text{scaled}} = \frac{S}{\sqrt{d_k}} Sscaled?=dk??S?
- 除以 d k \sqrt{d_k} dk??(鍵向量的維度),防止點積值過大導致 softmax 梯度消失:
-
Softmax 歸一化
- 對每行(每個查詢)做 softmax,得到注意力權重 A A A:
A = softmax ( S scaled ) A = \text{softmax}(S_{\text{scaled}}) A=softmax(Sscaled?) - 保證 ∑ j A i , j = 1 \sum_j A_{i,j} = 1 ∑j?Ai,j?=1,權重總和為 1。
- 對每行(每個查詢)做 softmax,得到注意力權重 A A A:
-
加權求和(Value 聚合)
- 用注意力權重 A A A 對 V V V 加權求和,得到最終輸出:
Output = A ? V \text{Output} = A \cdot V Output=A?V - 輸出維度: R n × d v \mathbb{R}^{n \times d_v} Rn×dv?。
- 用注意力權重 A A A 對 V V V 加權求和,得到最終輸出:
3. SDPA 的作用與優勢
? 核心作用:
- 讓模型動態關注輸入的不同部分(類似人類注意力機制)。
- 適用于序列數據(如文本、語音、視頻),捕捉長距離依賴。
? 優勢:
- 并行計算友好
- 矩陣乘法(GEMM)可高效并行加速(GPU/TPU 優化)。
- 可解釋性
- 注意力權重可視化(如
BertViz
)可分析模型關注哪些 token。
- 注意力權重可視化(如
- 靈活擴展
- 可結合 多頭注意力(Multi-Head Attention) 增強表達能力。
4. SDPA 的變體與優化
變體/優化 | 核心改進 | 應用場景 |
---|---|---|
多頭注意力(MHA) | 并行多個 SDPA,增強特征多樣性 | Transformer (BERT, GPT) |
FlashAttention | 優化內存訪問,減少 HBM 讀寫 | 長序列推理(如 8K+ tokens) |
Sparse Attention | 只計算局部或稀疏的注意力 | 降低計算復雜度(如 Longformer) |
Linear Attention | 用線性近似替代 softmax | 低資源設備(如 RetNet) |
5. 代碼實現(PyTorch 示例)
import torch
import torch.nn.functional as Fdef scaled_dot_product_attention(Q, K, V, mask=None):d_k = Q.size(-1)scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)attn_weights = F.softmax(scores, dim=-1)output = torch.matmul(attn_weights, V)return output# 示例輸入
Q = torch.randn(2, 5, 64) # (batch_size, seq_len, d_k)
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 128)
output = scaled_dot_product_attention(Q, K, V)
print(output.shape) # torch.Size([2, 5, 128])
6. 總結
- SDPA 是 Transformer 的基石,通過 Query-Key-Value 機制 + Softmax 歸一化 實現動態注意力。
- 關鍵優化點:縮放(防止梯度問題)、并行計算、內存效率(如 FlashAttention)。
- 現代優化(如 SageAttention2)進一步結合 量化、稀疏化、離群值處理 提升效率。
SDPA 及其變體已成為 NLP、CV、多模態領域的核心組件,理解其原理對模型優化至關重要。
SDPA計算過程舉例
我們通過一個具體的數值例子,逐步演示 SDPA 的計算過程。假設輸入如下(簡化版,便于手動計算):
輸入數據(假設 d_k = 2
, d_v = 3
)
- Query (Q):2 個查詢(
n=2
),每個查詢維度d_k=2
Q = [ 1 2 3 4 ] Q = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ \end{bmatrix} Q=[13?24?] - Key (K):3 個鍵(
m=3
),每個鍵維度d_k=2
K = [ 5 6 7 8 9 10 ] K = \begin{bmatrix} 5 & 6 \\ 7 & 8 \\ 9 & 10 \\ \end{bmatrix} K= ?579?6810? ? - Value (V):3 個值(
m=3
),每個值維度d_v=3
V = [ 1 0 1 0 1 0 1 1 0 ] V = \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 0 \\ 1 & 1 & 0 \\ \end{bmatrix} V= ?101?011?100? ?
Step 1: 計算 Query 和 Key 的點積(Dot-Product)
計算 S = Q K T S = QK^T S=QKT:
Q K T = [ 1 ? 5 + 2 ? 6 1 ? 7 + 2 ? 8 1 ? 9 + 2 ? 10 3 ? 5 + 4 ? 6 3 ? 7 + 4 ? 8 3 ? 9 + 4 ? 10 ] = [ 5 + 12 7 + 16 9 + 20 15 + 24 21 + 32 27 + 40 ] = [ 17 23 29 39 53 67 ] QK^T = \begin{bmatrix} 1 \cdot 5 + 2 \cdot 6 & 1 \cdot 7 + 2 \cdot 8 & 1 \cdot 9 + 2 \cdot 10 \\ 3 \cdot 5 + 4 \cdot 6 & 3 \cdot 7 + 4 \cdot 8 & 3 \cdot 9 + 4 \cdot 10 \\ \end{bmatrix} = \begin{bmatrix} 5+12 & 7+16 & 9+20 \\ 15+24 & 21+32 & 27+40 \\ \end{bmatrix} = \begin{bmatrix} 17 & 23 & 29 \\ 39 & 53 & 67 \\ \end{bmatrix} QKT=[1?5+2?63?5+4?6?1?7+2?83?7+4?8?1?9+2?103?9+4?10?]=[5+1215+24?7+1621+32?9+2027+40?]=[1739?2353?2967?]
Step 2: 縮放(Scaling)
除以 d k = 2 ≈ 1.414 \sqrt{d_k} = \sqrt{2} \approx 1.414 dk??=2?≈1.414:
S scaled = S 2 = [ 17 / 1.414 23 / 1.414 29 / 1.414 39 / 1.414 53 / 1.414 67 / 1.414 ] ≈ [ 12.02 16.26 20.51 27.58 37.48 47.38 ] S_{\text{scaled}} = \frac{S}{\sqrt{2}} = \begin{bmatrix} 17/1.414 & 23/1.414 & 29/1.414 \\ 39/1.414 & 53/1.414 & 67/1.414 \\ \end{bmatrix} \approx \begin{bmatrix} 12.02 & 16.26 & 20.51 \\ 27.58 & 37.48 & 47.38 \\ \end{bmatrix} Sscaled?=2?S?=[17/1.41439/1.414?23/1.41453/1.414?29/1.41467/1.414?]≈[12.0227.58?16.2637.48?20.5147.38?]
Step 3: Softmax 歸一化(計算注意力權重)
對每一行(每個 Query)做 softmax:
$\text{softmax}([12.02, 16.26, 20.51]) \approx [2.06 \times 10^{-4}, 0.016, 0.984] $
$\text{softmax}([27.58, 37.48, 47.38]) \approx [1.67 \times 10^{-9}, 0.0001, 0.9999] $
因此,注意力權重矩陣 A A A 為:
A ≈ [ 2.06 × 10 ? 4 0.016 0.984 1.67 × 10 ? 9 0.0001 0.9999 ] A \approx \begin{bmatrix} 2.06 \times 10^{-4} & 0.016 & 0.984 \\ 1.67 \times 10^{-9} & 0.0001 & 0.9999 \\ \end{bmatrix} A≈[2.06×10?41.67×10?9?0.0160.0001?0.9840.9999?]
解釋:
- 第 1 個 Query 主要關注第 3 個 Key(權重 0.984)。
- 第 2 個 Query 幾乎只關注第 3 個 Key(權重 0.9999)。
Step 4: 加權求和(聚合 Value)
計算 Output = A ? V \text{Output} = A \cdot V Output=A?V:
Output = [ 2.06 × 10 ? 4 ? 1 + 0.016 ? 0 + 0.984 ? 1 2.06 × 10 ? 4 ? 0 + 0.016 ? 1 + 0.984 ? 1 2.06 × 10 ? 4 ? 1 + 0.016 ? 0 + 0.984 ? 0 ] T ≈ [ 0.984 1.000 0.0002 ] T \text{Output} = \begin{bmatrix} 2.06 \times 10^{-4} \cdot 1 + 0.016 \cdot 0 + 0.984 \cdot 1 \\ 2.06 \times 10^{-4} \cdot 0 + 0.016 \cdot 1 + 0.984 \cdot 1 \\ 2.06 \times 10^{-4} \cdot 1 + 0.016 \cdot 0 + 0.984 \cdot 0 \\ \end{bmatrix}^T \approx \begin{bmatrix} 0.984 \\ 1.000 \\ 0.0002 \\ \end{bmatrix}^T Output= ?2.06×10?4?1+0.016?0+0.984?12.06×10?4?0+0.016?1+0.984?12.06×10?4?1+0.016?0+0.984?0? ?T≈ ?0.9841.0000.0002? ?T
Output = [ 0.984 1.000 0.0002 0.9999 0.9999 0.0001 ] \text{Output} = \begin{bmatrix} 0.984 & 1.000 & 0.0002 \\ 0.9999 & 0.9999 & 0.0001 \\ \end{bmatrix} Output=[0.9840.9999?1.0000.9999?0.00020.0001?]
解釋:
- 第 1 行:主要聚合了第 3 個 Value
[1, 1, 0]
,但受前兩個 Value 微弱影響。- 第 2 行:幾乎完全由第 3 個 Value 決定。
最終輸出
Output ≈ [ 0.984 1.000 0.0002 0.9999 0.9999 0.0001 ] \text{Output} \approx \begin{bmatrix} 0.984 & 1.000 & 0.0002 \\ 0.9999 & 0.9999 & 0.0001 \\ \end{bmatrix} Output≈[0.9840.9999?1.0000.9999?0.00020.0001?]
總結
- 點積:計算 Query 和 Key 的相似度。
- 縮放:防止梯度爆炸/消失。
- Softmax:歸一化為概率分布。
- 加權求和:聚合 Value 得到最終表示。
這個例子展示了 SDPA 如何動態分配注意力權重,并生成上下文感知的輸出。實際應用中(如 Transformer),還會結合 多頭注意力(Multi-Head Attention) 增強表達能力。