第8講、Multi-Head Attention 的核心機制與實現細節

🤔 為什么要有 Multi-Head Attention?

單個 Attention 機制雖然可以捕捉句子中不同詞之間的關系,但它只能關注一種角度或模式

Multi-Head 的作用是:

多個頭 = 多個視角同時觀察序列的不同關系

例如:

  • 一個頭可能專注主語和動詞的關系;
  • 另一個頭可能專注賓語和介詞;
  • 還有的可能學習句法結構或時態變化。

這些頭的表示最終會被拼接(concatenate)后再線性變換整合成更豐富的上下文表示。

🔍 技術深入:Multi-Head Attention 計算過程

Multi-Head Attention 的計算過程如下:

  1. 對輸入 X 進行線性變換得到 Q、K、V 矩陣
  2. 將 Q、K、V 分割成 h 個頭
  3. 每個頭獨立計算 Attention
  4. 拼接所有頭的輸出
  5. 最后進行一次線性變換
# 偽代碼實現
def multi_head_attention(X, h=8):# 線性變換獲得 Q, K, VQ = X @ W_q  # [batch_size, seq_len, d_model]K = X @ W_kV = X @ W_v# 分割成多頭Q_heads = split_heads(Q, h)  # [batch_size, h, seq_len, d_k]K_heads = split_heads(K, h)V_heads = split_heads(V, h)# 每個頭獨立計算 attentionattn_outputs = []for i in range(h):attn_output = scaled_dot_product_attention(Q_heads[:, i], K_heads[:, i], V_heads[:, i])attn_outputs.append(attn_output)# 拼接所有頭的輸出concat_output = concatenate(attn_outputs)  # [batch_size, seq_len, d_model]# 最后的線性變換output = concat_output @ W_oreturn output

🧮 如何判斷多少個頭(h)?

Transformer 默認將 d_model(模型維度)均分給每個頭。

設:

  • d_model = 512:模型的總嵌入維度
  • h = 8:頭數

那么每個頭的維度為:

d_k = d_model // h = 512 // 8 = 64

一般要求:

?? d_model 必須能被 h 整除。

📊 參數計算

Multi-Head Attention 中的參數量:

  • 輸入投影矩陣:3 × (d_model × d_model) = 3d_model2
  • 輸出投影矩陣:d_model × d_model = d_model2

總參數量:4 × d_model2

例如,當 d_model = 512 時,參數量約為 100 萬。


📌 頭的數量怎么選?

頭數 h每頭維度 d_k適用情境
1全部基線,最弱(沒多視角)
4中等小模型,如 tiny Transformer
864標準配置,如原始 Transformer
16更細粒度大模型中常見,如 BERT-large

實際訓練中:

  • 小任務(toy 或翻譯教學):用 2 或 4 個頭就夠了。
  • 真實 NLP 任務:建議使用 8 個頭(Transformer-base 規范)。
  • 太多頭而模型參數不足時,效果可能反而下降(每頭維度太小)。

📈 頭數與性能關系

研究表明,頭數與模型性能并非簡單的線性關系:

  • 頭數過少:無法捕捉多種語言模式
  • 頭數適中:性能最佳
  • 頭數過多:每個頭的維度變小,表達能力下降

外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳

🔬 實驗發現

Michel et al. (2019) 的研究《Are Sixteen Heads Really Better than One?》發現:

  1. 在訓練好的模型中,并非所有頭都同等重要
  2. 大多數情況下,可以剪枝掉一部分頭而不顯著影響性能
  3. 不同層的頭有不同的作用,底層頭和頂層頭往往更為重要

💡 Multi-Head Attention 的優勢

  1. 并行計算:所有頭可以并行計算,提高訓練效率
  2. 多角度表示:捕捉不同類型的依賴關系
  3. 信息冗余:多頭提供冗余信息,增強模型魯棒性
  4. 注意力分散:防止單一頭過度關注某些模式

🧠 總結一句話

Multi-Head 的本質是多角度捕捉詞與詞的關系,提升模型對上下文的理解能力。頭數越多,觀察角度越多,但每個頭的維度會減小,需注意平衡。


📊 Attention 可視化

不同頭學習到的注意力模式各不相同。以下是一個英語句子在 8 頭注意力機制下的可視化示例:

外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳

可以看到:

  • 頭1:關注相鄰詞的關系
  • 頭2:捕捉主語-謂語關系
  • 頭3:識別句法結構
  • 頭4:連接相關實體
  • 其他頭:各自專注于不同的語言特征

這種多角度的觀察使得 Transformer 能夠全面理解文本的語義和結構。


🖥? Streamlit 交互式可視化案例

想要直觀地理解 Multi-Head Attention?以下是一個使用 Streamlit 構建的交互式可視化案例,讓你可以實時探索不同頭的注意力模式:

import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel# 頁面設置
st.set_page_config(page_title="Multi-Head Attention 可視化", layout="wide")
st.title("Multi-Head Attention 可視化工具")# 加載預訓練模型
@st.cache_resource
def load_model():tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True)return tokenizer, modeltokenizer, model = load_model()# 用戶輸入
user_input = st.text_area("請輸入一段文本進行分析:", "Transformer是一種強大的神經網絡架構,它使用了Multi-Head Attention機制。",height=100)# 處理文本
if user_input:# 分詞并獲取注意力權重inputs = tokenizer(user_input, return_tensors="pt")outputs = model(**inputs)# 獲取所有層的注意力權重attentions = outputs.attentions  # tuple of tensors, one per layer# 選擇層layer_idx = st.slider("選擇Transformer層:", 0, len(attentions)-1, 0)# 獲取選定層的注意力權重layer_attentions = attentions[layer_idx].detach().numpy()# 獲取頭數num_heads = layer_attentions.shape[1]# 選擇頭head_idx = st.slider("選擇注意力頭:", 0, num_heads-1, 0)# 獲取選定頭的注意力權重head_attention = layer_attentions[0, head_idx]# 獲取標記tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])# 可視化fig, ax = plt.subplots(figsize=(10, 8))sns.heatmap(head_attention, xticklabels=tokens, yticklabels=tokens, cmap="YlGnBu", ax=ax)plt.title(f"第 {layer_idx+1} 層,第 {head_idx+1} 個頭的注意力權重")st.pyplot(fig)# 顯示注意力模式分析st.subheader("注意力模式分析")# 計算每個詞的平均注意力avg_attention = head_attention.mean(axis=0)top_indices = np.argsort(avg_attention)[-3:][::-1]st.write("這個注意力頭主要關注的詞:")for idx in top_indices:st.write(f"- {tokens[idx]}: {avg_attention[idx]:.4f}")# 添加交互式功能if st.checkbox("顯示所有頭的對比"):st.subheader("所有頭的注意力對比")# 為每個頭創建一個小型熱力圖# 計算行列數以適應任意數量的頭num_cols = 4num_rows = (num_heads + num_cols - 1) // num_cols  # 向上取整fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 3*num_rows))axes = axes.flatten()for h in range(num_heads):sns.heatmap(layer_attentions[0, h], xticklabels=[] if h < (num_heads-num_cols) else tokens, yticklabels=[] if h % num_cols != 0 else tokens, cmap="YlGnBu", ax=axes[h])axes[h].set_title(f"頭 {h+1}")# 隱藏未使用的子圖for h in range(num_heads, len(axes)):axes[h].axis('off')plt.tight_layout()st.pyplot(fig)# 添加解釋st.markdown("""### 如何解讀這個可視化:- 顏色越深表示注意力權重越高- 縱軸代表查詢詞(當前詞)- 橫軸代表鍵詞(被關注的詞)- 每個頭學習不同的關注模式通過調整滑塊,你可以探索不同層和不同頭的注意力模式,觀察模型如何理解文本中的關系。""")# 運行說明
st.sidebar.markdown("""
## 使用說明1. 在文本框中輸入你想分析的文本
2. 使用滑塊選擇要查看的層和注意力頭
3. 查看熱力圖了解詞與詞之間的注意力關系
4. 勾選"顯示所有頭的對比"可以同時查看所有頭的模式這個工具幫助你直觀理解 Multi-Head Attention 的工作原理和不同頭的功能分工。
""")# 代碼說明
with st.expander("查看完整代碼實現"):st.code("""
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel# 頁面設置
st.set_page_config(page_title="Multi-Head Attention 可視化", layout="wide")
st.title("Multi-Head Attention 可視化工具")# 加載預訓練模型
@st.cache_resource
def load_model():tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True)return tokenizer, modeltokenizer, model = load_model()# 用戶輸入和可視化邏輯
# ...此處省略,與上面代碼相同
""")### 🚀 如何運行這個可視化工具1. 安裝必要的依賴:
```bash
pip install streamlit torch transformers matplotlib seaborn
  1. 將上面的代碼保存為 attention_viz.py

  2. 運行 Streamlit 應用:

streamlit run attention_viz.py


這個交互式工具讓你可以:

  • 輸入任意文本并查看注意力分布
  • 選擇不同的 Transformer 層和注意力頭
  • 直觀對比不同頭學習到的不同模式
  • 分析哪些詞獲得了最高的注意力權重

通過這個可視化工具,你可以親自探索 Multi-Head Attention 的工作原理,加深對這一機制的理解。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/83442.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/83442.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/83442.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

百度智能云千帆攜手聯想,共創MCP生態宇宙

5月7日&#xff0c;2025聯想創新科技大會&#xff08;Tech World&#xff09;在上海世博中心舉行&#xff0c;本屆大會以“讓AI成為創新生產力”為主題。會上&#xff0c;聯想集團董事長兼CEO楊元慶展示了包括覆蓋全場景的超級智能體矩陣&#xff0c;包括個人超級智能體、企業超…

【OpenCV】幀差法、級聯分類器、透視變換

一、幀差法&#xff08;移動目標識別&#xff09;&#xff1a; 好處&#xff1a;開銷小&#xff0c;不怎么消耗CPU的算力&#xff0c;對硬件要求不高&#xff0c;但只適合固定攝像頭 1、優點 計算效率高&#xff0c;硬件要求 響應速度快&#xff0c;實時性強 直接利用連續幀…

數據庫遷移的藝術:團隊協作中的沖突預防與解決之道

title: 數據庫遷移的藝術:團隊協作中的沖突預防與解決之道 date: 2025/05/17 00:13:50 updated: 2025/05/17 00:13:50 author: cmdragon excerpt: 在團隊協作中,數據庫遷移腳本沖突是常見問題。通過Alembic工具,可以有效地管理和解決這些沖突。沖突預防的四原則包括功能分…

Linux常用命令43——bunzip2解壓縮bz2文件

在使用Linux或macOS日常開發中&#xff0c;熟悉一些基本的命令有助于提高工作效率&#xff0c;bunzip2可解壓縮.bz2格式的壓縮文件。bunzip2實際上是bzip2的符號連接&#xff0c;執行bunzip2與bzip2 -d的效果相同。本篇學習記錄bunzip2命令的基本使用。 首先查看幫助文檔&#…

盲盒:拆開未知的驚喜,收藏生活的儀式感

一、什么是盲盒&#xff1f;—— 一場關于“未知”的浪漫冒險 盲盒&#xff0c;是一種充滿神秘感的消費體驗&#xff1a; &#x1f381; 盒中藏驚喜——每個盲盒外觀相同&#xff0c;但內含隨機商品&#xff0c;可能是普通款、稀有款&#xff0c;甚至是“隱藏款”&#xff1b;…

Android 中使用通知(Kotlin 版)

1. 前置條件 Android Studio&#xff1a;確保使用最新版本&#xff08;2023.3.1&#xff09;目標 API&#xff1a;最低 API 21&#xff0c;兼容 Android 8.0&#xff08;渠道&#xff09;和 13&#xff08;權限&#xff09;依賴庫&#xff1a;使用 WorkManager 和 Notificatio…

使用大模型預測急性結石性疾病技術方案

目錄 1. 數據預處理與特征工程偽代碼 - 數據清洗與特征處理數據預處理流程圖2. 大模型構建與訓練偽代碼 - 模型訓練模型訓練流程圖3. 術前預測系統偽代碼 - 術前風險評估術前預測流程圖4. 術中實時調整系統偽代碼 - 術中風險預警術中調整流程圖5. 術后護理系統偽代碼 - 并發癥預…

每日Prompt:生成自拍照

提示詞 幫我生成一張圖片&#xff1a;圖片風格為「人像攝影」&#xff0c;請你畫一張及其平凡無奇的iPhone對鏡自拍照&#xff0c;主角是穿著JK風格cos服的可愛女孩&#xff0c;在自己精心布置的可按風格的房間內的落地鏡前用后置攝像頭隨手一拍的快照。照片開啟了閃光燈&…

動態規劃-64.最小路徑和-力扣(LetCode)

一、題目解析 從左上角到右下角使得數字總和最小且只能向下或向右移動 二、算法原理 1.狀態表示 我們需要求到達[i,j]位置時數字總和的最小值&#xff0c;所以dp[i][j]表示&#xff1a;到達[i,j]位置時&#xff0c;路徑數字總和的最小值。 2.狀態轉移方程 到達[i,j]之前要先…

LeetCode LCR 010 和為 K 的子數組 (Java)

兩種解法詳解&#xff1a;暴力枚舉與前綴和哈希表尋找和為k的子數組 在解決數組中和為k的連續子數組個數的問題時&#xff0c;我們可以采用不同的方法。本文將詳細解析兩種常見的解法&#xff1a;暴力枚舉法和前綴和結合哈希表的方法&#xff0c;分析它們的思路、優缺點及適用…

OpenVLA (2) 機器人環境和環境數據

文章目錄 [TOC](文章目錄) 前言1 BridgeData V21.1 概述1.2 硬件環境 2 數據集2.1 場景與結構2.2 數據結構2.2.1 images02.2.2 obs_dict.pkl2.2.3 policy_out.pkl 3 close question3.1 英偉達環境3.2 LIBERO 環境更適合仿真3.3 4090 運行問題 前言 按照筆者之前的行業經驗, 數…

深度學習(第3章——亞像素卷積和可形變卷積)

前言&#xff1a; 本章介紹了計算機識別超分領域和目標檢測領域中常常使用的兩種卷積變體&#xff0c;亞像素卷積&#xff08;Subpixel Convolution&#xff09;和可形變卷積&#xff08;Deformable Convolution&#xff09;&#xff0c;并給出對應pytorch的使用。 亞像素卷積…

大模型在腰椎間盤突出癥預測與治療方案制定中的應用研究

目錄 一、引言 1.1 研究背景 1.2 研究目的與意義 二、腰椎間盤突出癥概述 2.1 定義與病因 2.2 癥狀與診斷方法 2.3 治療方法概述 三、大模型技術原理與應用基礎 3.1 大模型的基本原理 3.2 大模型在醫療領域的應用現狀 3.3 用于腰椎間盤突出癥預測的可行性分析 四、…

Vue3學習(組合式API——ref模版引用與defineExpose編譯宏函數)

目錄 一、ref模版引用。 &#xff08;1&#xff09;基本介紹。 &#xff08;2&#xff09;核心基本步驟。(以獲取DOM、組件為例) &#xff08;3&#xff09;案例&#xff1a;獲取dom對象演示。 <1>需求&#xff1a;點擊按鈕&#xff0c;讓輸入框聚焦。 &#xff08;4&…

公鏈開發及其配套設施:錢包與區塊鏈瀏覽器

公鏈開發及其配套設施&#xff1a;錢包與區塊鏈瀏覽器的技術架構與生態實踐 ——2025年區塊鏈基礎設施建設的核心邏輯與創新突破 一、公鏈開發&#xff1a;構建去中心化世界的基石 1. 技術架構設計的三重挑戰 公鏈作為開放的區塊鏈網絡&#xff0c;需在性能、安全性與去中心…

Kotlin 作用域函數(let、run、with、apply、also)對比

Kotlin 的 作用域函數&#xff08;Scope Functions&#xff09; 是簡化代碼邏輯的重要工具&#xff0c;它們通過臨時作用域為對象提供更簡潔的操作方式。以下是 let、run、with、apply、also 的對比分析&#xff1a; 一、核心區別對比表 函數上下文對象引用返回值是否擴展函數…

14、Python時間表示:Unix時間戳、毫秒微秒精度與time模塊實戰

適合人群&#xff1a;零基礎自學者 | 編程小白快速入門 閱讀時長&#xff1a;約5分鐘 文章目錄 一、問題&#xff1a;計算機中的時間的表示、Unix時間點&#xff1f;1、例子1&#xff1a;計算機的“生日”&#xff1a;Unix時間點2、答案&#xff1a;&#xff08;1&#xff09;U…

AI日報 - 2024年5月17日

&#x1f31f; 今日概覽 (60秒速覽) ▎&#x1f916; 大模型前沿 | OpenAI推出自主編碼代理Codex&#xff1b;Google DeepMind發布Gemini驅動的編碼代理AlphaEvolve&#xff0c;能設計先進算法&#xff1b;Meta旗艦AI模型Llama 4 Behemoth發布推遲。 Codex能并行處理多任務&…

DriveMM:用于自動駕駛的一體化大型多模態模型——論文閱讀

《DriveMM: All-in-One Large Multimodal Model for Autonomous Driving》2024年12月發表&#xff0c;來自中山大學深圳分校和美團的論文。 大型多模態模型&#xff08;LMM&#xff09;通過整合大型語言模型&#xff0c;在自動駕駛&#xff08;AD&#xff09;中表現出卓越的理解…

C++_STL_map與set

1. 關聯式容器 在初階階段&#xff0c;我們已經接觸過STL中的部分容器&#xff0c;比如&#xff1a;vector、list、deque、 forward_list(C11)等&#xff0c;這些容器統稱為序列式容器&#xff0c;因為其底層為線性序列的數據結構&#xff0c;里面 存儲的是元素本身。那什么是…