本文帶你一步步理解 Transformer 中最核心的模塊:多頭注意力機制(Multi-Head Attention)。從原理到實現,配圖 + 舉例 + PyTorch 代碼,一次性說清楚!
什么是 Multi-Head Attention?
簡單說,多頭注意力就是一種讓模型在多個角度“看”一個序列的機制。
在自然語言中,一個詞的含義往往依賴于上下文,比如:
“我把蘋果給了她”
模型在處理“蘋果”時,需要關注“我”“她”“給了”等詞,多頭注意力就是這樣一種機制——從多個角度理解上下文關系。
Self-Attention 是什么?為什么還要多頭?
在講“多頭”之前,咱們先回顧一下基礎的 Self-Attention。
Self-Attention(自注意力)機制的目標是:
讓每個詞都能“關注”整個句子里的其他詞,融合上下文。
它的核心步驟是:
-
對每個詞生成 Query、Key、Value 向量
-
用 Query 和所有 Key 做點積,算出每個詞對其他詞的關注度(打分)
-
用 Softmax 得到權重,對 Value 加權平均,生成當前詞的新表示
這樣做的好處是:詞的語義表示不再是孤立的,而是上下文相關的。
Self-Attention vs Multi-Head Attention
但問題是——單頭 Self-Attention 視角有限。就像一個老師只能從一種角度講課。
于是,Multi-Head Attention 應運而生!
特性 | Self-Attention(單頭) | Multi-Head Attention(多頭) |
---|---|---|
輸入映射矩陣 | 一組 Q/K/V 線性變換 | 多組 Q/K/V,每個頭一組 |
學習角度 | 單一視角 | 多角度并行理解 |
表達能力 | 有限 | 更豐富、強大 |
結構 | 簡單 | 并行多個頭 + 合并輸出 |
一句話總結:
Multi-Head Attention = 多個不同“視角”的 Self-Attention 并行處理 + 合并結果
?多頭注意力:8個腦袋一起思考!
多頭 = 多個“單頭注意力”并行處理!
每個頭使用不同的線性變換矩陣,所以能從不同視角處理數據:
-
第1個頭可能專注短依賴(like 動詞和主語)
-
第2個頭可能專注實體關系(我 vs 她)
-
第3個頭可能關注時間順序(“給了”前后)
-
……共用同一個輸入,學習到不同特征!
多頭的步驟:
-
將輸入向量(如512維)拆成多個頭(比如8個,每個64維)
-
每個頭獨立進行 attention
-
所有頭的輸出拼接
-
再過一次線性變換,融合成最終輸出
?PyTorch 實現(簡潔版)
我們來看下 PyTorch 中的簡化實現:
import torch
import torch.nn as nn
import copydef clones(module, N):return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])def attention(query, key, value, mask=None, dropout=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = torch.softmax(scores, dim=-1)if dropout:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attnclass MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):super().__init__()assert d_model % h == 0self.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, mask=None):if mask is not None:mask = mask.unsqueeze(1)nbatches = query.size(0)query, key, value = [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for lin, x in zip(self.linears, (query, key, value))]x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)
舉個例子:多頭在實際模型中的作用
假設輸入是句子:
"The animal didn't cross the street because it was too tired."
多頭注意力的不同頭可能會:
-
🧠 頭1:關注“animal”和“it”之間的指代關系;
-
📐 頭2:識別“because”和“tired”之間的因果聯系;
-
📚 頭3:注意句子的結構層次……
所以說,多頭注意力本質上是一個“并行注意力專家系統”!
?總結
項目 | 解釋 |
---|---|
目的 | 提升模型表達能力,從多個角度理解輸入 |
核心機制 | 將向量分頭 → 每頭獨立 attention → 合并輸出 |
技術關鍵 | view , transpose , matmul , softmax , 拼接線性層 |
推薦學習路徑
-
🔹 理解 Self-Attention 的點積公式
-
🔹 搞懂
view
,transpose
等張量操作 -
🔹 看 Transformer 整體結構,關注每層作用