從代碼學習深度學習 - 多頭注意力 PyTorch 版

文章目錄

  • 前言
  • 一、多頭注意力機制介紹
    • 1.1 工作原理
    • 1.2 優勢
    • 1.3 代碼實現概述
  • 二、代碼解析
    • 2.1 導入依賴
      • 序列掩碼函數
    • 2.2 掩碼 Softmax 函數
    • 2.3 縮放點積注意力
    • 2.4 張量轉換函數
    • 2.5 多頭注意力模塊
    • 2.6 測試代碼
  • 總結


前言

在深度學習領域,注意力機制(Attention Mechanism)是自然語言處理(NLP)和計算機視覺(CV)等任務中的核心組件之一。特別是多頭注意力(Multi-Head Attention),作為 Transformer 模型的基礎,極大地提升了模型對復雜依賴關系的捕捉能力。本文通過分析一個完整的 PyTorch 實現,帶你深入理解多頭注意力的原理和代碼實現。我們將從代碼入手,逐步解析每個函數和類的功能,結合文字說明,讓你不僅能運行代碼,還能理解其背后的設計邏輯。無論你是初學者還是有一定經驗的開發者,這篇博客都將幫助你更直觀地掌握多頭注意力機制。
完整代碼:下載鏈接


一、多頭注意力機制介紹

多頭注意力(Multi-Head Attention)是 Transformer 模型的核心組件之一,廣泛應用于自然語言處理(NLP)、計算機視覺(CV)等領域。它通過并行運行多個注意力頭(Attention Heads),允許模型同時關注輸入序列中的不同部分,從而捕捉更豐富的語義和上下文依賴關系。相比單一的注意力機制,多頭注意力極大地增強了模型的表達能力,能夠處理復雜的模式和長距離依賴。

1.1 工作原理

多頭注意力的核心思想是將輸入的查詢(Queries)、鍵(Keys)和值(Values)通過線性變換映射到多個子空間,每個子空間由一個獨立的注意力頭處理。具體步驟如下:

  1. 線性變換:對輸入的查詢、鍵和值分別應用線性層,將其映射到隱藏維度(num_hiddens),并分割為多個頭的表示。
  2. 縮放點積注意力:每個注意力頭獨立計算縮放點積注意力(Scaled Dot-Product Attention),即通過查詢和鍵的點積計算注意力分數,再與值加權求和。
  3. 并行計算:多個注意力頭并行運行,每個頭關注輸入的不同方面,生成各自的輸出。
  4. 合并與變換:將所有頭的輸出拼接起來,并通過一個線性層融合,得到最終的多頭注意力輸出。

這種設計允許模型在不同子空間中學習不同的特征,例如在 NLP 任務中,一個頭可能關注句法結構,另一個頭可能關注語義關系。
在這里插入圖片描述

1.2 優勢

  • 多樣性:多頭機制使模型能夠從多個角度理解輸入,捕捉多樣化的模式。
  • 并行性:多頭計算可以高效并行化,提升計算效率。
  • 穩定性:通過縮放點積(除以特征維度的平方根),緩解了高維點積導致的數值不穩定問題。

1.3 代碼實現概述

在本文的實現中,我們使用 PyTorch 構建了一個完整的多頭注意力模塊,包含以下關鍵部分:

  • 序列掩碼:處理變長序列,屏蔽無效位置。
  • 縮放點積注意力:實現單個注意力頭的計算邏輯。
  • 張量轉換:通過 transpose_qkvtranspose_output 函數實現多頭分割與合并。
  • 多頭注意力類:整合所有組件,完成并行計算和輸出融合。

接下來的代碼解析將詳細展示這些部分的實現,幫助你從代碼層面深入理解多頭注意力的每一步計算邏輯。

二、代碼解析

以下是代碼的完整實現和詳細解析,代碼按照 Jupyter Notebook(在最開始給出了完整代碼下載鏈接) 的結構組織,并附上文字說明,幫助你理解每個部分的邏輯。

2.1 導入依賴

首先,我們導入必要的 Python 包,包括數學運算庫 math 和 PyTorch 的核心模塊 torchnn

# 導入包
import math
import torch
from torch import nn
  • math:用于計算縮放點積注意力中的歸一化因子(即特征維度的平方根)。
  • torch:PyTorch 的核心庫,提供張量運算和自動求導功能。
  • nn:PyTorch 的神經網絡模塊,包含 nn.Modulenn.Linear 等工具,用于構建神經網絡層。

序列掩碼函數

在處理序列數據(如句子)時,不同序列的長度可能不同,我們需要通過掩碼(Mask)來屏蔽無效位置,防止模型關注這些填充區域。以下是 sequence_mask 函數的實現:

def sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相關的項,使超出有效長度的位置被設置為指定值參數:X: 輸入張量,形狀 (batch_size, 最大序列長度, 特征維度) 或 (batch_size, 最大序列長度)valid_len: 有效長度張量,形狀 (batch_size,),表示每個序列的有效長度value: 屏蔽值,標量,默認值為 0,用于填充無效位置返回:輸出張量,形狀與輸入 X 相同,無效位置被設置為 value"""maxlen = X.size(1)  # 最大序列長度,標量# 創建掩碼,形狀 (1, 最大序列長度),與 valid_len 比較生成布爾張量,形狀 (batch_size, 最大序列長度)mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]# 將掩碼取反后,X 的無效位置被設置為 valueX[~mask] = valuereturn X

解析

  • 輸入
    • X:輸入張量,通常是序列數據,可能包含填充(padding)部分。
    • valid_len:每個樣本的有效長度,例如 [3, 2] 表示第一個樣本有 3 個有效 token,第二個樣本有 2 個。
    • value:用于填充無效位置的值,默認為 0。
  • 邏輯
    • maxlen 獲取序列的最大長度(即張量的第二維)。
    • torch.arange(maxlen) 創建一個從 0 到 maxlen-1 的序列,形狀為 (1, maxlen)
    • 通過廣播機制,與 valid_len(形狀 (batch_size, 1))比較,生成布爾掩碼 mask,形狀為 (batch_size, maxlen)
    • mask 表示哪些位置是有效的(True),哪些是無效的(False)。
    • 使用 ~mask 選擇無效位置,將其值設置為 value
  • 輸出:修改后的張量 X,無效位置被設置為 value,形狀不變。

作用:該函數用于在注意力計算中屏蔽填充區域,確保模型只關注有效 token。

2.2 掩碼 Softmax 函數

在注意力機制中,我們需要對注意力分數應用 Softmax 操作,將其轉換為概率分布。但由于序列長度不同,需要屏蔽無效位置的貢獻。以下是 masked_softmax 函數的實現:

import torch
import torch.nn.functional as Fdef masked_softmax(X, valid_lens):"""通過在最后一個軸上掩蔽元素來執行softmax操作,忽略無效位置參數:X: 輸入張量,形狀 (batch_size, 查詢個數, 鍵-值對個數),3D張量valid_lens: 有效長度張量,形狀 (batch_size,) 或 (batch_size, 查詢個數),1D或2D張量,表示每個序列的有效長度,即每個查詢可以參考的有效鍵值對長度返回:輸出張量,形狀 (batch_size, 查詢個數, 鍵-值對個數),softmax后的注意力權重"""if valid_lens is None:# 如果沒有有效長度,直接在最后一個軸上應用softmaxreturn F.softmax(X, dim=-1)shape 

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/76794.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/76794.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/76794.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

學術版 GPT 網頁

學術版 GPT 網頁 1. 學術版 GPT 網頁非盈利版References https://academic.chatwithpaper.org/ 1. 學術版 GPT 網頁非盈利版 arXiv 全文翻譯&#xff0c;免費且無需登錄。 更換模型 System prompt: Serve me as a writing and programming assistant. 界面外觀 References …

MarkDown 輸出表格的方法

MarkDown用來輸出表格很簡單&#xff0c;比Word手搓表格簡單多了&#xff0c;而且方便修改。 MarkDown代碼&#xff1a; |A|B|C|D| |:-|-:|:-:|-| |1|b|c|d| |2|b|c|d| |3|b|c|d| |4|b|c|d| |5|b|c|d|顯示效果&#xff1a; ABCD1bcd2bcd3bcd4bcd5bcd A列強制左對齊&#xf…

MetaGPT深度解析:重塑AI協作開發的智能體框架實踐指南

一、框架架構與技術突破 1.1 系統架構設計 graph TBA[自然語言需求] --> B(需求解析引擎)B --> C{角色路由系統}C --> D[產品經理Agent]C --> E[架構師Agent]C --> F[工程師Agent]D --> G[PRD文檔]E --> H[架構圖]F --> I[代碼文件]G --> J[知識共…

自用:在使用SpringBoot做學生信息管理系統時遇到的問題

1、在做完查詢測試時&#xff0c;一直報出404找不到錯誤&#xff0c;原因是沒有為各個層的實現類添加注解 2、改完之后發現測試沒有數據&#xff0c;是因為我寫的返回值類型為空&#xff0c;應該返回一個List< Student > 3、我沒有想到要寫Result實體類&#xff0c;因為不…

SQLite + Redis = Redka

Redka 是一個基于 SQLite 實現的 Redis 替代產品&#xff0c;實現了 Redis 的核心功能&#xff0c;并且完全兼容 Redis API。它可以用于輕量級緩存、嵌入式系統、快速原型開發以及需要事務 ACID 特性的鍵值操作等場景。 功能特性 Redka 的主要特點包括&#xff1a; 使用 SQLi…

202529 | RocketMQ 簡介 + 安裝 + 集群搭建 + 消費模式 + 消費者組

RocketMQ簡介 RocketMQ 簡介 Apache RocketMQ 是一款開源的 分布式消息中間件&#xff08;Message Queue, MQ&#xff09;&#xff0c;由阿里巴巴團隊研發并捐贈給 Apache 基金會&#xff0c;現已成為頂級項目。它專為 高吞吐、低延遲、高可靠 的分布式場景設計&#xff0c;廣…

Go語言--語法基礎4--基本數據類型--整數類型

整型是所有編程語言里最基礎的數據類型。 Go 語言支持如下所示的這些整型類型。 需要注意的是&#xff0c; int 和 int32 在 Go 語言里被認為是兩種不同的類型&#xff0c;編譯器也不會幫你自動做類型轉換&#xff0c; 比如以下的例子會有編譯錯誤&#xff1a; var value2 in…

競拍商城:電商創新的博弈場與未來趨勢

競拍商城&#xff1a;電商創新的博弈場與未來趨勢 在傳統電商趨于同質化的今天&#xff0c;競拍商城憑借其獨特的交易機制和用戶激勵模式&#xff0c;成為電商領域的新寵。通過結合拍賣的博弈屬性與電商的便捷性&#xff0c;競拍商城不僅重塑了消費體驗&#xff0c;更催生了全…

Linux : 多線程互斥

目錄 一 前言 二 線程互斥 三 Mutex互斥量 1. 定義一個鎖&#xff08;造鎖&#xff09; 2. 初始化鎖 3. 上鎖 4. 解鎖 5. 摧毀鎖 四 鎖的使用 五 鎖的宏初始化 六 鎖的原理 1.如何看待鎖&#xff1f; 2. 如何理解加鎖和解鎖的本質 七 c封裝互斥鎖 八 可重入…

論文閱讀筆記——Reactive Diffusion Policy

RDP 論文 通過 AR 提供實時觸覺/力反饋&#xff1b;慢速擴散策略&#xff0c;用于預測低頻潛在空間中的高層動作分塊&#xff1b;快速非對稱分詞器實現閉環反饋控制。 ACT、 π 0 \pi_0 π0? 采取了動作分塊&#xff0c;在動作分塊執行期間處于開環狀態&#xff0c;無法及時響…

swagger 注釋說明

一、接口注釋核心字段 在 Go 的路由處理函數&#xff08;Handler&#xff09;上方添加注釋&#xff0c;支持以下常用注解&#xff1a; 注解名稱用途說明示例格式Summary接口簡要描述Summary 創建用戶Description接口詳細說明Description 通過用戶名和郵箱創建新用戶Tags接口分…

STM32 HAL庫 OLED驅動實現

一、概述 1.1 OLED 顯示屏簡介 OLED&#xff08;Organic Light - Emitting Diode&#xff09;即有機發光二極管&#xff0c;與傳統的 LCD 顯示屏相比&#xff0c;OLED 具有自發光、視角廣、響應速度快、對比度高、功耗低等優點。在嵌入式系統中&#xff0c;OLED 顯示屏常被用…

Web開發-JavaEE應用動態接口代理原生反序列化危險Invoke重寫方法利用鏈

知識點&#xff1a; 1、安全開發-JavaEE-動態代理&序列化&反序列化 2、安全開發-JavaEE-readObject&toString方法 一、演示案例-WEB開發-JavaEE-動態代理 動態代理 代理模式Java當中最常用的設計模式之一。其特征是代理類與委托類有同樣的接口&#xff0c;代理類…

K8s是常用命令和解釋

K8s高頻命令 獲取資源信息&#xff0c;如獲取 Pod、Service、Deployment等資源狀態信息 kubectl get創建資源如創建Pod、Service、Deployment等資源 kubectl create刪除資源&#xff0c;如刪除Pod、Service、Deployment等資源 kubectl delete 應用配置文件&#xff0c;如引用D…

【模態分解】EMD-經驗模態分解

算法配置頁面&#xff0c;也可以一鍵導出結果數據 報表自定義繪制 獲取和下載【PHM學習軟件PHM源碼】的方式 獲取方式&#xff1a;Docshttps://jcn362s9p4t8.feishu.cn/wiki/A0NXwPxY3ie1cGkOy08cru6vnvc

TDengine 語言連接器(Go)

簡介 driver-go 是 TDengine 的官方 Go 語言連接器&#xff0c;實現了 Go 語言 database/sql 包的接口。Go 開發人員可以通過它開發存取 TDengine 集群數據的應用軟件。 Go 版本兼容性 支持 Go 1.14 及以上版本。 支持的平臺 原生連接支持的平臺和 TDengine 客戶端驅動支持…

鏈接世界:計算機網絡的核心與前沿

計算機網絡引言 在數字化時代&#xff0c;計算機網絡已經成為我們日常生活和工作中不可或缺的基礎設施。從簡單的局域網&#xff08;LAN&#xff09;到全球互聯網&#xff0c;計算機網絡將數以億計的設備連接在一起&#xff0c;推動了信息交換、資源共享以及全球化的進程。 什…

AI agents系列之全面介紹

隨著大型語言模型(LLMs)的出現,人工智能(AI)取得了巨大的飛躍。這些強大的系統徹底改變了自然語言處理,但當它們與代理能力結合時,才真正釋放出潛力——能夠自主地推理、規劃和行動。這就是LLM代理大顯身手的地方,它們代表了我們與AI交互以及利用AI的方式的范式轉變。 …

如何使用AI輔助開發CSS3 - 通義靈碼功能全解析

一、引言 CSS3 作為最新的 CSS 標準&#xff0c;引入了眾多新特性&#xff0c;如彈性布局、網格布局等&#xff0c;極大地豐富了網頁樣式的設計能力。然而&#xff0c;CSS3 的樣式規則繁多&#xff0c;記憶所有規則對于開發者來說幾乎是不可能的任務。在實際開發中&#xff0c…

復刻系列-星穹鐵道 3.2 版本先行展示頁

復刻星穹鐵道 3.2 版本先行展示頁 0. 視頻 手搓&#xff5e;星穹鐵道&#xff5e;展示頁&#xff5e;&#xff5e;&#xff5e; 1. 基本信息 作者: 啊是特嗷桃系列: 復刻系列官方的網站: 《崩壞&#xff1a;星穹鐵道》3.2版本「走過安眠地的花叢」專題展示頁現已上線復刻的網…