從代碼學習深度學習 - Transformer PyTorch 版

文章目錄

  • 前言
  • 1. 位置編碼(Positional Encoding)
  • 2. 多頭注意力機制(Multi-Head Attention)
  • 3. 前饋網絡與殘差連接(Position-Wise FFN & AddNorm)
    • 3.1 基于位置的前饋網絡(PositionWiseFFN)
    • 3.2 殘差連接和層規范化(AddNorm)
  • 4. 編碼器(Encoder)
    • 4.1 編碼器塊(EncoderBlock)
    • 4.2 Transformer 編碼器(TransformerEncoder)
  • 5. 解碼器(Decoder)
    • 5.1 解碼器塊(DecoderBlock)
    • 5.2 Transformer 解碼器(TransformerDecoder)
  • 6. 完整 Transformer 模型
    • 使用示例
  • 總結


前言

Transformer 模型自 2017 年在論文《Attention is All You Need》中提出以來,徹底改變了自然語言處理(NLP)領域,并在計算機視覺等其他領域展現了強大的潛力。與傳統的 RNN 和 LSTM 相比,Transformer 通過自注意力機制(Self-Attention)實現了并行計算,極大地提高了訓練效率和模型性能。本博客將通過 PyTorch 實現的 Transformer 模型代碼,深入剖析其核心組件,包括多頭注意力機制、位置編碼、編碼器和解碼器等。我們將結合代碼和文字說明,逐步拆解 Transformer 的實現邏輯,幫助讀者從代碼層面理解這一經典模型的精髓。
在這里插入圖片描述

本文基于提供的代碼文件(PE.pyEnDecoder.pyMHA.pyTransformer.ipynb),完整呈現 Transformer 的 PyTorch 實現,并通過清晰的目錄結構和代碼注釋,帶領大家從零開始學習 Transformer 的構建過程。關于訓練和可視化部分,這里忽略掉,但是你仍然可以在下面的鏈接里找到所有的源代碼,其中提供了豐富的注釋。無論你是深度學習初學者還是希望深入理解 Transformer 的開發者,這篇博客都將為你提供一個清晰的學習路徑。

完整代碼:下載鏈接


1. 位置編碼(Positional Encoding)

Transformer 的自注意力機制不包含序列的位置信息,因此需要通過位置編碼(Positional Encoding)為每個詞元添加位置信息。以下是 PE.py 中實現的位置編碼類,它通過正弦和余弦函數生成固定位置編碼。

import torch
import torch.nn as nnclass PositionalEncoding(nn.Module):"""位置編碼在Transformer中,由于自注意力機制不含位置信息,需要額外添加位置編碼在位置嵌入矩陣P中,行代表詞元在序列中的位置,列代表位置編碼的不同維度"""def __init__(self, num_hiddens, dropout, max_len=1000):"""初始化位置編碼參數:num_hiddens (int): 隱藏層維度,即位置編碼的維度dropout (float): dropout概率max_len (int, 可選): 最大序列長度,默認為1000"""super(PositionalEncoding, self).__init__()# 初始化丟棄層self.dropout = nn.Dropout(dropout)# 創建位置編碼矩陣P,形狀為(1, max_len, num_hiddens)self.P = torch.zeros((1, max_len, num_hiddens))# 計算位置編碼的正弦和余弦函數輸入# X形狀: (max_len, num_hiddens/2)X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)# 偶數維度賦值正弦,奇數維度賦值余弦self.P[:, :, 0::2] = torch.sin(X)self.P[:, :, 1::2] = torch.cos(X)def forward(self, X):"""前向傳播參數:X (torch.Tensor): 輸入張量,形狀為(batch_size, seq_len, embed_dim)返回:torch.Tensor: 添加位置編碼后的張量,形狀為(batch_size, seq_len, embed_dim)"""# 將位置編碼加到輸入X上,截取與X長度匹配的部分X = X + self.P[:, :X.shape[1], :].to(X.device)# 應用丟棄并返回結果return self.dropout(X)

代碼解析

  • 初始化PositionalEncoding 類根據隱藏層維度(num_hiddens)和最大序列長度(max_len)生成一個位置編碼矩陣 P。該矩陣的每一行表示一個位置,每一列對應一個編碼維度。
  • 正弦和余弦編碼:通過正弦(sin)和余弦(cos)函數為不同位置和維度生成編碼值,公式為:
    P E ( p o s , 2 i ) = sin ? ( p o s 1000 0 2 i / d ) , P E ( p o s , 2 i + 1 ) = cos ? ( p o s 1000 0 2 i / d ) PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d}}\right), \quad PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d}}\right) PE(pos,2i)=sin(100002i/dpos?),PE(pos,2i+1)=cos(100002i/dpos?)
    其中 pos 是位置索引,i 是維度索引,d 是隱藏層維度。
  • 前向傳播:將輸入張量 X 與位置編碼矩陣 P 相加,并應用 dropout 以增強模型的魯棒性。

位置編碼的作用是將序列的位置信息嵌入到詞嵌入中,使得 Transformer 能夠區分相同詞元在不同位置的語義。

2. 多頭注意力機制(Multi-Head Attention)

多頭注意力機制是 Transformer 的核心組件,允許模型并行計算多個注意力頭,從而捕獲序列中不同方面的依賴關系。以下是 MHA.py 中實現的多頭注意力機制。

import math
import torch
from torch import nn
import torch.nn.functional as Fdef sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相關的項,使超出有效長度的位置被設置為指定值"""maxlen = X.size(1)mask = torch.arange(maxlen, dtype=torch

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/76112.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/76112.shtml
英文地址,請注明出處:http://en.pswp.cn/web/76112.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

閱讀分析Linux0.11 /boot/head.s

目錄 初始化IDT、IDTR和GDT、GDTR檢查協處理器并設置CR0寄存器初始化頁表和CR3寄存器,開啟分頁 初始化IDT、IDTR和GDT、GDTR startup_32:movl $0x10,%eaxmov %ax,%dsmov %ax,%esmov %ax,%fsmov %ax,%gslss _stack_start,%espcall setup_idtcall setup_gdtmovl $0x1…

33、單元測試實戰練習題

以下是三個練習題的具體實現方案,包含完整代碼示例和詳細說明: 練習題1:TDD實現博客評論功能 步驟1:編寫失敗測試 # tests/test_blog.py import unittest from blog import BlogPost, Comment, InvalidCommentErrorclass TestBl…

16-算法打卡-哈希表-兩個數組的交集-leetcode(349)-第十六天

1 題目地址 349. 兩個數組的交集 - 力扣(LeetCode)349. 兩個數組的交集 - 給定兩個數組 nums1 和 nums2 ,返回 它們的 交集 。輸出結果中的每個元素一定是 唯一 的。我們可以 不考慮輸出結果的順序 。 示例 1:輸入:nu…

SciPy庫詳解

SciPy 是一個用于數學、科學和工程計算的 Python 庫,它建立在 NumPy 之上,提供了許多高效的算法和工具,用于解決各種科學計算問題。 CONTENT 1. 數值積分功能代碼 2. 優化問題求解功能代碼3. 線性代數運算功能代碼 4. 信號處理功能代碼 5. 插…

杰弗里·辛頓:深度學習教父

名人說:路漫漫其修遠兮,吾將上下而求索。—— 屈原《離騷》 創作者:Code_流蘇(CSDN)(一個喜歡古詩詞和編程的Coder😊) 杰弗里辛頓:當堅持遇見突破,AI迎來新紀元 一、人物簡介 杰弗…

BladeX單點登錄與若依框架集成實現

1. 概述 本文檔詳細介紹了將BladeX認證系統與若依(RuoYi)框架集成的完整實現過程。集成采用OAuth2.0授權碼流程,使用戶能夠通過BladeX賬號直接登錄若依系統,實現無縫單點登錄體驗。 2. 系統架構 2.1 總體架構 #mermaid-svg-YxdmBwBtzGqZHMme {font-fa…

初識Redis · set和zset

目錄 前言: set 基本命令 交集并集差集 內部編碼和應用場景 zset 基本命令 交集并集差集 內部編碼和應用場景 應用場景(AI生成) 排行榜系統 應用背景 設計思路 熱榜系統 應用背景 設計思路 熱度計算方式 總結對比表 前言&a…

playwright 教程高級篇:掌握網頁自動化與驗證碼處理等關鍵技術詳解

Playwright 教程高級篇:掌握網頁自動化與驗證碼處理等關鍵技術詳解 本教程將帶您一步步學習如何使用 Playwright——一個強大的瀏覽器自動化工具,來完成網頁任務,例如提交鏈接并處理旋轉驗證碼。我們將按照典型的自動化流程順序,從啟動瀏覽器到關閉瀏覽器,詳細講解每個步驟…

數據結構(完)

樹 二叉樹 構建二叉樹 int value;Node left;Node right;public Node(int val) {valueval;} 節點的添加 Node rootnull;public void insert(int num) {Node nodenew Node(num);if(rootnull) {rootnode;return;}Node index root;while(true) {//插入的節點值小if(index.value&g…

FastAPI與SQLAlchemy數據庫集成與CRUD操作

title: FastAPI與SQLAlchemy數據庫集成與CRUD操作 date: 2025/04/16 09:50:57 updated: 2025/04/16 09:50:57 author: cmdragon excerpt: FastAPI與SQLAlchemy集成基礎包括環境準備、數據庫連接配置和模型定義。CRUD操作通過數據訪問層封裝和路由層實現,確保線程安全和事務…

一個基于Django的寫字樓管理系統實現方案

一個基于Django的寫字樓管理系統實現方案 用戶現在需要我用Django來編寫一個寫字樓管理系統的Web版本,要求包括增刪改查寫字樓的HTML頁面,視頻管理功能,本地化部署,以及人員權限管理,包含完整的代碼結構和功能實現&am…

mongodb在window10中創建副本集的方法,以及node.js連接副本集的方法

創建Mongodb的副本集最好是新建一個文件夾,如D:/data,不要在mongodb安裝文件夾里面創建副本集,雖然這樣也可以,但是容易造成誤操作或路徑混亂;在新建文件夾里與現有 MongoDB 數據隔離,避免誤操作影響原有數…

Maven 多倉庫與鏡像配置全攻略:從原理到企業級實踐

Maven 多倉庫與鏡像配置全攻略:從原理到企業級實踐 一、核心概念:Repository 與 Mirror 的本質差異 在 Maven 依賴管理體系中,repository與mirror是構建可靠依賴解析鏈的兩大核心組件,其核心區別如下: 1. Repositor…

STM32 四足機器人常見問題匯總

文章不介紹具體參數,有需求可去網上搜索。 特別聲明:不論年齡,不看學歷。既然你對這個領域的東西感興趣,就應該不斷培養自己提出問題、思考問題、探索答案的能力。 提出問題:提出問題時,應說明是哪款產品&a…

MySQL 中 `${}` 和 `#{}` 占位符詳解及面試高頻考點

文章目錄 一、概述二、#{} 和 ${} 的核心區別1. 底層機制代碼示例 2. 核心區別總結 三、為什么表名只能用 ${}?1. 預編譯機制的限制2. 動態表名的實現 四、安全性注意事項1. ${} 的風險場景2. 安全實踐 五、面試高頻考點1. 基礎原理類問題**問題 1**:**問…

C語言編譯預處理2

#include <XXXX.h>和#include <XXXX.c> #include "XXXX.h" 是 C 語言中一條預處理指令 #include <XXXX.h>&#xff1a;這種形式用于包含系統標準庫的頭文件。預處理器會在系統默認的頭文件搜索路徑中查找XXXX.h 文件。例如在 Linux 系統中&#…

Elasticvue-輕量級Elasticsearch可視化管理工具

Elasticvue一個免費且開源的 Elasticsearch 在線可視化客戶端&#xff0c;用于管理 Elasticsearch 集群中的數據&#xff0c;完全支持 Elasticsearch 版本 8.x 和 7.x. 功能特色&#xff1a; 集群概覽索引和別名管理分片管理搜索和編輯文檔REST 查詢快照和存儲庫管理支持國際…

Git提交規范及最佳實踐

Git 提交規范通常是為了提高代碼提交的可讀性、可維護性和自動化效率&#xff08;如生成 ChangeLog&#xff09;。以下是常見的 Conventional Commits 規范&#xff0c;結合社區最佳實踐總結而成&#xff1a; 1. 提交格式 每次提交的 commit message 應包含三部分&#xff1a;…

Ubuntu中snap

通過Snap可以安裝眾多的軟件包。需要注意的是&#xff0c;snap是一種全新的軟件包管理方式&#xff0c;它類似一個容器擁有一個應用程序所有的文件和庫&#xff0c;各個應用程序之間完全獨立。所以使用snap包的好處就是它解決了應用程序之間的依賴問題&#xff0c;使應用程序之…

android studio 運行java main報錯

運行某個帶main函數的java文件報錯 Could not create task :app:Test.main(). > SourceSet with name main not found. 解決辦法&#xff1a;在工程的.idea/gradle.xml 文件下添加&#xff1a; <option name"delegatedBuild" value"false" /&g…