MHSA: a PyTorch code example of Multi-head Self Attention, with detailed explanatory comments:
- Linear projection
  Three linear layers produce the query (Q), key (K), and value (V) matrices:
  $Q = W_q \cdot x,\quad K = W_k \cdot x,\quad V = W_v \cdot x$
- Split into heads
  Each matrix is split into $h$ heads:
  $Q \rightarrow [Q_1, Q_2, \dots, Q_h]$, with each $Q_i \in \mathbb{R}^{d_k}$
- Compute attention scores
  Each head computes scaled dot-product attention:
  $\text{Attention}(Q_i, K_i, V_i) = \text{softmax}\!\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$
- Merge heads
  The outputs of all heads are concatenated and passed through a linear layer (a step-by-step shape walkthrough follows this list):
  $\text{MultiHead} = W_o \cdot [\text{head}_1; \text{head}_2; \dots; \text{head}_h]$
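For concreteness, here is a minimal step-by-step sketch of the four stages above using plain tensor operations; the toy sizes (batch_size=2, seq_len=5, embed_dim=16, num_heads=4) and the random W_q, W_k, W_v matrices are illustrative assumptions, not part of the reference implementation further below.

import torch

# toy sizes, chosen only for illustration
batch_size, seq_len, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads

x = torch.randn(batch_size, seq_len, embed_dim)

# 1. linear projections (random matrices stand in for learned W_q, W_k, W_v)
W_q, W_k, W_v = (torch.randn(embed_dim, embed_dim) for _ in range(3))
Q, K, V = x @ W_q, x @ W_k, x @ W_v                                    # each (2, 5, 16)

# 2. split into heads
Q = Q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)  # (2, 4, 5, 4)
K = K.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# 3. scaled dot-product attention per head
scores = Q @ K.transpose(-2, -1) / head_dim ** 0.5                    # (2, 4, 5, 5)
weights = scores.softmax(dim=-1)
out = weights @ V                                                      # (2, 4, 5, 4)

# 4. merge heads (the final W_o projection is left out of this sketch)
out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
print(out.shape)  # torch.Size([2, 5, 16])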
Mathematical principle:
Multi-head attention lets the model attend to information from different representation subspaces simultaneously:
$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$
where each head is computed as:
$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$
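One point worth making explicit: the per-head matrices $W_i^Q$ in this formula are usually not stored separately; the code below uses a single embed_dim x embed_dim projection whose column slices play the role of the $W_i^Q$. A minimal sketch of this equivalence, with assumed toy dimensions and no bias term:

import torch

torch.manual_seed(0)
embed_dim, num_heads = 16, 4
head_dim = embed_dim // num_heads

x = torch.randn(2, 5, embed_dim)          # (batch, seq, embed_dim)
W_q = torch.randn(embed_dim, embed_dim)   # one packed projection matrix

# project once, then split the last dimension into heads
Q_split = (x @ W_q).view(2, 5, num_heads, head_dim)

# equivalently, project with each head's own W_i^Q (a column slice of W_q)
for i in range(num_heads):
    W_i = W_q[:, i * head_dim:(i + 1) * head_dim]   # (embed_dim, head_dim)
    assert torch.allclose(x @ W_i, Q_split[..., i, :], atol=1e-5)
print("per-head projections match column slices of the packed W_q")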
Below is a PyTorch implementation of Multi-head Self Attention, with detailed comments:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        """
        embed_dim: dimension of the input vectors
        num_heads: number of attention heads
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # dimension of each head

        # embed_dim must be divisible by num_heads
        assert self.head_dim * num_heads == embed_dim

        # linear projection layers
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """
        x: input tensor of shape (batch_size, seq_len, embed_dim)
        """
        batch_size = x.shape[0]  # x: [4, 10, 512]

        # 1. Linear projections
        Q = self.query(x)  # (batch_size, seq_len, embed_dim)  [4, 10, 512]
        K = self.key(x)    # (batch_size, seq_len, embed_dim)  [4, 10, 512]
        V = self.value(x)  # (batch_size, seq_len, embed_dim)  [4, 10, 512]

        # 2. Split into heads
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [4, 8, 10, 64]
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [4, 8, 10, 64]
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [4, 8, 10, 64]
        # shape is now (batch_size, num_heads, seq_len, head_dim)

        # 3. Compute attention scores: Q·K^T / sqrt(d_k)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)
        # [4, 8, 10, 64] x [4, 8, 64, 10] = [4, 8, 10, 10]
        # shape: (batch_size, num_heads, seq_len, seq_len)

        # 4. Apply softmax to obtain attention weights
        attention = F.softmax(energy, dim=-1)
        # shape: (batch_size, num_heads, seq_len, seq_len)

        # 5. Weighted sum of the values
        out = torch.matmul(attention, V)  # [4, 8, 10, 10] x [4, 8, 10, 64] = [4, 8, 10, 64]
        # shape: (batch_size, num_heads, seq_len, head_dim)

        # 6. Merge heads
        out = out.permute(0, 2, 1, 3).contiguous()
        out = out.view(batch_size, -1, self.embed_dim)
        # shape: (batch_size, seq_len, embed_dim)

        # 7. Final linear transformation
        out = self.fc_out(out)
        return out


# Usage example
if __name__ == "__main__":
    # settings
    embed_dim = 512   # input dimension
    num_heads = 8     # number of attention heads
    seq_len = 10      # sequence length
    batch_size = 4    # batch size

    # create the multi-head attention module
    mha = MultiHeadAttention(embed_dim, num_heads)

    # generate random input data
    input_data = torch.randn(batch_size, seq_len, embed_dim)

    # forward pass
    output = mha(input_data)

    print("Input shape:", input_data.shape)
    print("Output shape:", output.shape)
Example output:
Input shape: torch.Size([4, 10, 512])
Output shape: torch.Size([4, 10, 512])
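As a sanity check, the output can be compared against torch.nn.MultiheadAttention after copying the weights into its packed parameters; this sketch assumes in_proj_weight stacks the Q, K and V projections row-wise in that order, and that dropout and masking are left at their defaults (disabled).

import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 512, 8
x = torch.randn(4, 10, embed_dim)

custom = MultiHeadAttention(embed_dim, num_heads)
ref = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

with torch.no_grad():
    # assumed packed layout: [W_q; W_k; W_v] stacked along dim 0
    ref.in_proj_weight.copy_(torch.cat([custom.query.weight,
                                        custom.key.weight,
                                        custom.value.weight], dim=0))
    ref.in_proj_bias.copy_(torch.cat([custom.query.bias,
                                      custom.key.bias,
                                      custom.value.bias], dim=0))
    ref.out_proj.weight.copy_(custom.fc_out.weight)
    ref.out_proj.bias.copy_(custom.fc_out.bias)

out_custom = custom(x)
out_ref, _ = ref(x, x, x)
print(torch.allclose(out_custom, out_ref, atol=1e-5))  # expected: True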
This implementation keeps the input and output dimensions identical, so it can be integrated directly into a Transformer or similar architecture.
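As an illustration of that last point, here is a minimal sketch of a post-norm Transformer encoder block built around the MultiHeadAttention module above; the feed-forward width (ffn_dim) and the ReLU activation are illustrative choices, not prescribed by the text.

import torch
import torch.nn as nn

class TransformerEncoderBlock(nn.Module):
    """Post-norm encoder block: attention and feed-forward sublayers with residuals."""
    def __init__(self, embed_dim, num_heads, ffn_dim):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.norm1(x + self.attn(x))  # self-attention sublayer + residual
        x = self.norm2(x + self.ffn(x))   # feed-forward sublayer + residual
        return x

block = TransformerEncoderBlock(embed_dim=512, num_heads=8, ffn_dim=2048)
print(block(torch.randn(4, 10, 512)).shape)  # torch.Size([4, 10, 512])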