Scaled Dot-Product Attention
-
論文地址
https://arxiv.org/pdf/1706.03762
注意力機制介紹
-
縮放點積注意力是Transformer模型的核心組件,用于計算序列中不同位置之間的關聯程度。其核心思想是通過查詢向量(query)和鍵向量(key)的點積來獲取注意力分數,再通過縮放和歸一化處理,最后與值向量(value)加權求和得到最終表示。
?
數學公式
-
縮放點積注意力的計算過程可分為三個關鍵步驟:
- 點積計算與縮放:通過矩陣乘法計算查詢向量與鍵向量的相似度,并使用 d k \sqrt{d_k} dk?? 縮放
- 掩碼處理(可選):對需要忽略的位置施加極大負值掩碼
- Softmax歸一化:將注意力分數轉換為概率分布
- 加權求和:用注意力權重對值向量進行加權
公式表達為:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V Attention(Q,K,V)=softmax(dk??QKT?)V
其中:- Q ∈ R s e q _ l e n × d _ k Q \in \mathbb{R}^{seq\_len \times d\_k} Q∈Rseq_len×d_k:查詢矩陣
- K ∈ R s e q _ l e n × d _ k K \in \mathbb{R}^{seq\_len \times d\_k} K∈Rseq_len×d_k:鍵矩陣
- V ∈ R s e q _ l e n × d _ k V \in \mathbb{R}^{seq\_len \times d\_k} V∈Rseq_len×d_k:值矩陣
s e q _ l e n seq\_len seq_len 為序列長度, d _ k d\_k d_k 為embedding的維度。
代碼實現
-
計算注意力分數
#!/usr/bin/env python # -*- coding: utf-8 -*- import torchdef calculate_attention(query, key, value, mask=None):"""計算縮放點積注意力分數參數說明:query: [batch_size, n_heads, seq_len, d_k]key: [batch_size, n_heads, seq_len, d_k] value: [batch_size, n_heads, seq_len, d_k]mask: [batch_size, seq_len, seq_len](可選)"""d_k = key.shape[-1]key_transpose = key.transpose(-2, -1) # 轉置最后兩個維度# 計算縮放點積 [batch, h, seq_len, seq_len]att_scaled = torch.matmul(query, key_transpose) / d_k ** 0.5# 掩碼處理(解碼器自注意力使用)if mask is not None:att_scaled = att_scaled.masked_fill_(mask=mask, value=-1e9)# Softmax歸一化att_softmax = torch.softmax(att_scaled, dim=-1)# 加權求和 [batch, h, seq_len, d_k]return torch.matmul(att_softmax, value)
-
相關解釋
-
輸入張量 query, key, value的形狀
如果是直接計算的話,那么shape是 [batch_size, seq_len, d_model]
當然為了學習更多的表征,一般都是多頭注意力,這時候shape則是[batch_size, n_heads, seq_len, d_k]
其中
-
batch_size:批量
-
n_heads:注意力頭的數量
-
seq_len: 序列的長度
-
d_model: embedding維度
-
d_k: d_k = d_model / n_heads
-
-
代碼中的shape轉變
-
key_transpose :key的轉置矩陣
由 key 轉置了最后兩個維度,維度從 [batch_size, n_heads, seq_len, d_k] 轉變為 [batch_size, n_heads, d_k, seq_len]
-
**att_scaled **:縮放點積
由 query 和 key 通過矩陣相乘得到
[batch_size, n_heads, seq_len, d_k] @ [batch_size, n_heads, d_k, seq_len] --> [batch_size, n_heads, seq_len, seq_len]
-
att_score: 注意力分數
由兩個矩陣相乘得到
[batch_size, n_heads, seq_len, seq_len] @ [batch_size, n_heads, seq_len, d_k] --> [batch_size, n_heads, seq_len, d_k]
-
-
使用示例
-
測試代碼
if __name__ == "__main__":# 模擬輸入:batch_size=2, 8個注意力頭,序列長度512,d_k=64x = torch.ones((2, 8, 512, 64))# 計算注意力(未使用掩碼)att_score = calculate_attention(x, x, x)print("輸出形狀:", att_score.shape) # torch.Size([2, 8, 512, 64])print("注意力分數示例:\n", att_score[0,0,:3,:3])
在實際使用中通常會將此實現封裝為
nn.Module
并與位置編碼、殘差連接等組件配合使用,構建完整的Transformer層。