🤔 為什么要有 Multi-Head Attention?
單個 Attention 機制雖然可以捕捉句子中不同詞之間的關系,但它只能關注一種角度或模式。
Multi-Head 的作用是:
多個頭 = 多個視角同時觀察序列的不同關系。
例如:
- 一個頭可能專注主語和動詞的關系;
- 另一個頭可能專注賓語和介詞;
- 還有的可能學習句法結構或時態變化。
這些頭的表示最終會被拼接(concatenate)后再線性變換整合成更豐富的上下文表示。
🔍 技術深入:Multi-Head Attention 計算過程
Multi-Head Attention 的計算過程如下:
- 對輸入 X 進行線性變換得到 Q、K、V 矩陣
- 將 Q、K、V 分割成 h 個頭
- 每個頭獨立計算 Attention
- 拼接所有頭的輸出
- 最后進行一次線性變換
# 偽代碼實現
def multi_head_attention(X, h=8):# 線性變換獲得 Q, K, VQ = X @ W_q # [batch_size, seq_len, d_model]K = X @ W_kV = X @ W_v# 分割成多頭Q_heads = split_heads(Q, h) # [batch_size, h, seq_len, d_k]K_heads = split_heads(K, h)V_heads = split_heads(V, h)# 每個頭獨立計算 attentionattn_outputs = []for i in range(h):attn_output = scaled_dot_product_attention(Q_heads[:, i], K_heads[:, i], V_heads[:, i])attn_outputs.append(attn_output)# 拼接所有頭的輸出concat_output = concatenate(attn_outputs) # [batch_size, seq_len, d_model]# 最后的線性變換output = concat_output @ W_oreturn output
🧮 如何判斷多少個頭(h
)?
Transformer 默認將 d_model
(模型維度)均分給每個頭。
設:
d_model = 512
:模型的總嵌入維度h = 8
:頭數
那么每個頭的維度為:
d_k = d_model // h = 512 // 8 = 64
一般要求:
??
d_model
必須能被h
整除。
📊 參數計算
Multi-Head Attention 中的參數量:
- 輸入投影矩陣:3 × (d_model × d_model) = 3d_model2
- 輸出投影矩陣:d_model × d_model = d_model2
總參數量:4 × d_model2
例如,當 d_model = 512 時,參數量約為 100 萬。
📌 頭的數量怎么選?
頭數 h | 每頭維度 d_k | 適用情境 |
---|---|---|
1 | 全部 | 基線,最弱(沒多視角) |
4 | 中等 | 小模型,如 tiny Transformer |
8 | 64 | 標準配置,如原始 Transformer |
16 | 更細粒度 | 大模型中常見,如 BERT-large |
實際訓練中:
- 小任務(toy 或翻譯教學):用 2 或 4 個頭就夠了。
- 真實 NLP 任務:建議使用 8 個頭(Transformer-base 規范)。
- 太多頭而模型參數不足時,效果可能反而下降(每頭維度太小)。
📈 頭數與性能關系
研究表明,頭數與模型性能并非簡單的線性關系:
- 頭數過少:無法捕捉多種語言模式
- 頭數適中:性能最佳
- 頭數過多:每個頭的維度變小,表達能力下降
🔬 實驗發現
Michel et al. (2019) 的研究《Are Sixteen Heads Really Better than One?》發現:
- 在訓練好的模型中,并非所有頭都同等重要
- 大多數情況下,可以剪枝掉一部分頭而不顯著影響性能
- 不同層的頭有不同的作用,底層頭和頂層頭往往更為重要
💡 Multi-Head Attention 的優勢
- 并行計算:所有頭可以并行計算,提高訓練效率
- 多角度表示:捕捉不同類型的依賴關系
- 信息冗余:多頭提供冗余信息,增強模型魯棒性
- 注意力分散:防止單一頭過度關注某些模式
🧠 總結一句話
Multi-Head 的本質是多角度捕捉詞與詞的關系,提升模型對上下文的理解能力。頭數越多,觀察角度越多,但每個頭的維度會減小,需注意平衡。
📊 Attention 可視化
不同頭學習到的注意力模式各不相同。以下是一個英語句子在 8 頭注意力機制下的可視化示例:
可以看到:
- 頭1:關注相鄰詞的關系
- 頭2:捕捉主語-謂語關系
- 頭3:識別句法結構
- 頭4:連接相關實體
- 其他頭:各自專注于不同的語言特征
這種多角度的觀察使得 Transformer 能夠全面理解文本的語義和結構。
🖥? Streamlit 交互式可視化案例
想要直觀地理解 Multi-Head Attention?以下是一個使用 Streamlit 構建的交互式可視化案例,讓你可以實時探索不同頭的注意力模式:
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel# 頁面設置
st.set_page_config(page_title="Multi-Head Attention 可視化", layout="wide")
st.title("Multi-Head Attention 可視化工具")# 加載預訓練模型
@st.cache_resource
def load_model():tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True)return tokenizer, modeltokenizer, model = load_model()# 用戶輸入
user_input = st.text_area("請輸入一段文本進行分析:", "Transformer是一種強大的神經網絡架構,它使用了Multi-Head Attention機制。",height=100)# 處理文本
if user_input:# 分詞并獲取注意力權重inputs = tokenizer(user_input, return_tensors="pt")outputs = model(**inputs)# 獲取所有層的注意力權重attentions = outputs.attentions # tuple of tensors, one per layer# 選擇層layer_idx = st.slider("選擇Transformer層:", 0, len(attentions)-1, 0)# 獲取選定層的注意力權重layer_attentions = attentions[layer_idx].detach().numpy()# 獲取頭數num_heads = layer_attentions.shape[1]# 選擇頭head_idx = st.slider("選擇注意力頭:", 0, num_heads-1, 0)# 獲取選定頭的注意力權重head_attention = layer_attentions[0, head_idx]# 獲取標記tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])# 可視化fig, ax = plt.subplots(figsize=(10, 8))sns.heatmap(head_attention, xticklabels=tokens, yticklabels=tokens, cmap="YlGnBu", ax=ax)plt.title(f"第 {layer_idx+1} 層,第 {head_idx+1} 個頭的注意力權重")st.pyplot(fig)# 顯示注意力模式分析st.subheader("注意力模式分析")# 計算每個詞的平均注意力avg_attention = head_attention.mean(axis=0)top_indices = np.argsort(avg_attention)[-3:][::-1]st.write("這個注意力頭主要關注的詞:")for idx in top_indices:st.write(f"- {tokens[idx]}: {avg_attention[idx]:.4f}")# 添加交互式功能if st.checkbox("顯示所有頭的對比"):st.subheader("所有頭的注意力對比")# 為每個頭創建一個小型熱力圖# 計算行列數以適應任意數量的頭num_cols = 4num_rows = (num_heads + num_cols - 1) // num_cols # 向上取整fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 3*num_rows))axes = axes.flatten()for h in range(num_heads):sns.heatmap(layer_attentions[0, h], xticklabels=[] if h < (num_heads-num_cols) else tokens, yticklabels=[] if h % num_cols != 0 else tokens, cmap="YlGnBu", ax=axes[h])axes[h].set_title(f"頭 {h+1}")# 隱藏未使用的子圖for h in range(num_heads, len(axes)):axes[h].axis('off')plt.tight_layout()st.pyplot(fig)# 添加解釋st.markdown("""### 如何解讀這個可視化:- 顏色越深表示注意力權重越高- 縱軸代表查詢詞(當前詞)- 橫軸代表鍵詞(被關注的詞)- 每個頭學習不同的關注模式通過調整滑塊,你可以探索不同層和不同頭的注意力模式,觀察模型如何理解文本中的關系。""")# 運行說明
st.sidebar.markdown("""
## 使用說明1. 在文本框中輸入你想分析的文本
2. 使用滑塊選擇要查看的層和注意力頭
3. 查看熱力圖了解詞與詞之間的注意力關系
4. 勾選"顯示所有頭的對比"可以同時查看所有頭的模式這個工具幫助你直觀理解 Multi-Head Attention 的工作原理和不同頭的功能分工。
""")# 代碼說明
with st.expander("查看完整代碼實現"):st.code("""
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel# 頁面設置
st.set_page_config(page_title="Multi-Head Attention 可視化", layout="wide")
st.title("Multi-Head Attention 可視化工具")# 加載預訓練模型
@st.cache_resource
def load_model():tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True)return tokenizer, modeltokenizer, model = load_model()# 用戶輸入和可視化邏輯
# ...此處省略,與上面代碼相同
""")### 🚀 如何運行這個可視化工具1. 安裝必要的依賴:
```bash
pip install streamlit torch transformers matplotlib seaborn
-
將上面的代碼保存為
attention_viz.py
-
運行 Streamlit 應用:
streamlit run attention_viz.py
這個交互式工具讓你可以:
- 輸入任意文本并查看注意力分布
- 選擇不同的 Transformer 層和注意力頭
- 直觀對比不同頭學習到的不同模式
- 分析哪些詞獲得了最高的注意力權重
通過這個可視化工具,你可以親自探索 Multi-Head Attention 的工作原理,加深對這一機制的理解。