PyTorch 實戰:Transformer 模型搭建全解析

Transformer 作為一種強大的序列到序列模型,憑借自注意力機制在諸多領域大放異彩。它能并行處理序列,有效捕捉上下文關系,其架構包含編碼器與解碼器,各由多層組件構成,涉及自注意力、前饋神經網絡、歸一化和 Dropout 等關鍵環節 。下面我們深入探討其核心要點,并結合代碼實現進行詳細解讀。

一、Transformer 核心公式與機制

(一)自注意力計算

自注意力機制是 Transformer 的核心,其計算基于公式\(Attention(Q, K, V)=softmax\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V\)?。其中,Q、K、V分別是查詢、鍵和值矩陣,由輸入X分別乘以對應的權重矩陣\(W_Q\)、\(W_K\)、\(W_V\)得到 。\(d_{k}\)表示鍵的維度,除以\(\sqrt{d_{k}}\)?,一方面可防止\(QK^{T}\)過大導致 softmax 計算溢出,另一方面能讓\(QK^{T}\)結果滿足均值為 0、方差為 1 的分布 。\(QK^{T}\)本質上是計算向量間的余弦相似度,反映向量方向上的相似程度。

(二)多頭注意力機制

多頭注意力機制將輸入x拆分為h份,獨立計算h組不同的線性投影得到各自的Q、K、V?,然后并行計算注意力,最后拼接h個注意力池化結果,并通過可學習的線性投影產生最終輸出。這種設計使每個頭能關注輸入的不同部分,增強了模型對復雜函數的表示能力。

(三)位置編碼

由于 Transformer 沒有循環結構,位置編碼用于保留序列中的位置信息,確保模型在處理序列時能感知元素的位置。

二、自注意力與多頭注意力的實現

(一)自注意力實現

在 PyTorch 中,自注意力模塊Self_Attention的實現如下:

python

import numpy as np
import torch
from torch import nnclass Self_Attention(nn.Module):def __init__(self, input_dim, dim_k, dim_v):super(Self_Attention, self).__init__()self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self._norm_fact = 1 / np.sqrt(dim_k)def forward(self, x):Q = self.q(x) K = self.k(x) V = self.v(x) atten = nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1))) * self._norm_factoutput = torch.bmm(atten, V)return outputX = torch.randn(4, 3, 2)
self_atten = Self_Attention(2, 4, 5) 
res = self_atten(X)
print(res.shape) 

在這段代碼中,Self_Attention類繼承自nn.Module?。__init__方法初始化了線性層qkv?,并計算了歸一化因子_norm_fact?。forward方法實現了自注意力的計算過程,先通過線性層得到Q、K、V?,然后計算注意力權重atten?,最后得到輸出output?。

(二)多頭注意力實現

多頭注意力模塊Self_Attention_Muti_Head的實現如下:

python

import torch
import torch.nn as nnclass Self_Attention_Muti_Head(nn.Module):def __init__(self,input_dim,dim_k,dim_v,nums_head):super(Self_Attention_Muti_Head,self).__init__()assert dim_k % nums_head == 0assert dim_v % nums_head == 0self.q = nn.Linear(input_dim,dim_k)self.k = nn.Linear(input_dim,dim_k)self.v = nn.Linear(input_dim,dim_v)self.nums_head = nums_headself.dim_k = dim_kself.dim_v = dim_vself._norm_fact = 1 / np.sqrt(dim_k)def forward(self,x):Q = self.q(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.nums_head) K = self.k(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.nums_head) V = self.v(x).reshape(-1,x.shape[0],x.shape[1],self.dim_v // self.nums_head)atten = nn.Softmax(dim=-1)(torch.matmul(Q,K.permute(0,1,3,2))) output = torch.matmul(atten,V).reshape(x.shape[0],x.shape[1],-1) return outputx=torch.rand(1,3,4)
atten=Self_Attention_Muti_Head(4,4,4,2)
y=atten(x)
print(y.shape) 

在這個類中,__init__方法進行了參數校驗和模塊初始化 。forward方法將輸入x經過線性層變換后,重塑形狀以適應多頭計算,接著計算注意力權重并得到輸出,最后將多頭的結果拼接起來。

三、注意力機制的拓展

(一)視覺注意力機制

視覺注意力機制主要包括空間域、通道域和混合域三種。空間域注意力通過對圖片空間域信息進行變換,生成掩碼并打分來提取關鍵信息;通道域注意力為每個通道分配權重,代表模塊有 SENet,通過全局池化提取通道權重,進而調整特征圖;混合域注意力則結合了空間域和通道域的信息 。

(二)通道域注意力(SENet)實現

SENet 的實現代碼如下:

python

class SELayer(nn.Module):def __init__(self, channel, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x)

SELayer類中,__init__方法初始化了平均池化層avg_pool和全連接層序列fc?。forward方法實現了 SENet 的核心操作,先通過平均池化進行 Squeeze 操作,再經過全連接層進行 Excitation 操作,最后將生成的權重與原特征圖相乘,實現對特征圖的增強 。

(三)門控注意力機制(GCT)

GCT 是一種能提升卷積網絡泛化能力的通道間建模結構。它包含全局上下文嵌入、通道規范化和門控適應三個部分。全局上下文嵌入模塊匯聚每個通道的全局上下文信息;通道規范化構建神經元競爭關系;門控適應加入門限機制,促進神經元的協同或競爭關系 。

其實現代碼如下:

python

class GCT(nn.Module):def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):super(GCT, self).__init__()self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))self.epsilon = epsilonself.mode = modeself.after_relu = after_reludef forward(self, x):if self.mode == 'l2':embedding = (x.pow(2).sum((2, 3), keepdim=True) + self.epsilon).pow(0.5) * self.alphanorm = self.gamma / ((embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5))elif self.mode == 'l1':if not self.after_relu:_x = torch.abs(x)else:_x = xembedding = _x.sum((2, 3), keepdim=True) * self.alphanorm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)gate = 1. + torch.tanh(embedding * norm + self.beta)return x * gate

GCT類中,__init__方法初始化了可訓練參數alphagammabeta以及其他超參數 。forward方法根據不同的模式(l2l1)計算嵌入和歸一化結果,最后通過門控機制得到輸出 。GCT 通常添加在 Conv 層前,訓練時可先凍結原模型訓練 GCT,再解凍微調 。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/76986.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/76986.shtml
英文地址,請注明出處:http://en.pswp.cn/web/76986.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

網頁不同渲染方式的應對與反爬機制的處理——python爬蟲

文章目錄 寫在前面爬蟲習慣web 網頁渲染方式服務器渲染客戶端渲染 反爬機制使用session對象使用cookie讓請求頭信息更豐富使用代理和隨機延遲 寫在前面 本文是對前兩篇文章所介紹的內容的補充,在了解前兩篇文章——《爬蟲入門與requests庫的使用》和《BeautifulSou…

RK3588平臺用v4l工具調試USB攝像頭實踐(亮度,飽和度,對比度,色相等)

目錄 前言:v4l-utils簡介 一:查找當前的攝像頭設備 二:查看當前攝像頭支持的v4l2-ctl調試參數 三根據提示設置對應參數,在提示范圍內設置 四:常用調試命令 五:應用內執行命令方法 前言:v4l-utils簡介 v4l-utils工具是由Linu…

Spring Security基礎入門

本入門案例主要演示Spring Security在Spring Boot中的安全管理效果。為了更好地使用Spring Boot整合實現Spring Security安全管理功能,體現案例中Authentication(認證)和Authorization(授權)功能的實現,本案…

Trae+DeepSeek學習Python開發MVC框架程序筆記(二):使用4個文件實現MVC框架

修改上節文件,將test2.py拆分為4個文件,目錄結構如下: mvctest/ │── model.py # 數據模型 │── view.py # 視圖界面 │── controller.py # 控制器 │── main.py # 程序入口其中model.py代碼如下&#xff…

從認證到透傳:用 Nginx 為 EasySearch 構建一體化認證網關

在構建本地或云端搜索引擎系統時,EasySearch 憑借其輕量、高性能、易部署等優勢,逐漸成為眾多開發者和技術愛好者的首選。但在實際部署過程中,如何借助 Nginx 為 EasySearch 提供高效、穩定且安全的訪問入口,尤其是在身份認證方面…

CPU 虛擬化機制——受限直接執行 (LDE)

1. 引言:CPU虛擬化的核心問題 讓多個進程看似同時運行在一個物理CPU上。核心思想是時分共享 (time sharing) CPU。為了實現高效且可控的時分共享,本章介紹了一種關鍵機制,稱為受限直接執行 (Limited Direct Execution, LDE)。 1.1 LDE的基本…

linux 中斷子系統鏈式中斷編程

直接貼代碼了&#xff1a; 虛擬中斷控制器代碼&#xff0c;chained_virt.c #include<linux/kernel.h> #include<linux/module.h> #include<linux/clk.h> #include<linux/err.h> #include<linux/init.h> #include<linux/interrupt.h> #inc…

容器修仙傳 我的靈根是Pod 第10章 心魔大劫(RBAC與SecurityContext)

第四卷&#xff1a;飛升之劫化神篇 第10章 心魔大劫&#xff08;RBAC與SecurityContext&#xff09; 血月當空&#xff0c;林衍的混沌靈根正在異變。 每道經脈都爬滿黑色紋路&#xff0c;神識海中回蕩著蠱惑之音&#xff1a;"破開藏經閣第九層禁制…奪取《太古弒仙訣》……

基于c#,wpf,ef框架,sql server數據庫,音樂播放器

詳細視頻: 【基于c#,wpf,ef框架,sql server數據庫&#xff0c;音樂播放器。-嗶哩嗶哩】 https://b23.tv/ZqmOKJ5

精益數據分析(21/126):剖析創業增長引擎與精益畫布指標

精益數據分析&#xff08;21/126&#xff09;&#xff1a;剖析創業增長引擎與精益畫布指標 大家好&#xff01;在創業和數據分析的探索道路上&#xff0c;我一直希望能和大家攜手共進&#xff0c;共同學習。今天&#xff0c;我們繼續深入研讀《精益數據分析》&#xff0c;剖析…

Spark-streaming核心編程

1.導入依賴?&#xff1a; <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming-kafka-0-10_2.12</artifactId> <version>3.0.0</version> </dependency> 2.編寫代碼?&#xff1a; 創建Sp…

Kafka的ISR機制是什么?如何保證數據一致性?

一、Kafka ISR機制深度解析 1. ISR機制定義 ISR&#xff08;In-Sync Replicas&#xff09;是Kafka保證數據一致性的核心機制&#xff0c;由Leader副本&#xff08;復雜讀寫&#xff09;和Follower副本(負責備份)組成。當Follower副本的延遲超過replica.lag.time.max.ms&#…

Docker 基本概念與安裝指南

Docker 基本概念與安裝指南 一、Docker 核心概念 1. 容器&#xff08;Container&#xff09; 容器是 Docker 的核心運行單元&#xff0c;本質是一個輕量級的沙盒環境。它基于鏡像創建&#xff0c;包含應用程序及其運行所需的依賴&#xff08;如代碼、庫、環境變量等&#xf…

數據庫監控 | MongoDB監控全解析

PART 01 MongoDB&#xff1a;靈活、可擴展的文檔數據庫 MongoDB作為一款開源的NoSQL數據庫&#xff0c;憑借其靈活的數據模型&#xff08;基于BSON的文檔存儲&#xff09;、水平擴展能力&#xff08;分片集群&#xff09;和高可用性&#xff08;副本集架構&#xff09;&#x…

OpenFeign和Gateway

OpenFeign和Gateway 一.OpenFeign介紹二.快速上手1.引入依賴2.開啟openfeign的功能3.編寫客戶端4.修改遠程調用代碼5.測試 三.OpenFeign參數傳遞1.傳遞單個參數2.多個參數、傳遞對象和傳遞JSON字符串3.最佳方式寫代碼繼承的方式抽取的方式 四.部署OpenFeign五.統一服務入口-Gat…

spark-streaming(二)

DStream創建&#xff08;kafka數據源&#xff09; 1.在idea中的 pom.xml 中添加依賴 <dependency><groupId>org.apache.spark</groupId><artifactId>spark-streaming-kafka-0-10_2.12</artifactId><version>3.0.0</version> </…

JAVA聚焦OutOfMemoryError 異常

個人主頁 文章專欄 在正文開始前&#xff0c;我想多說幾句&#xff0c;也就是吐苦水吧…最近這段時間一直想寫點東西&#xff0c;停下來反思思考一下。 心中萬言&#xff0c;真正執筆時又不知先寫些什么。通常這個時候&#xff0c;我都會隨便寫寫&#xff0c;文風極像散文&…

如何在Spring Boot中配置自定義端口運行應用程序

Spring Boot 應用程序默認在端口 8080 上運行嵌入式 Web 服務器&#xff08;如 Tomcat、Jetty 或 Undertow&#xff09;。然而&#xff0c;在開發、測試或生產環境中&#xff0c;開發者可能需要將應用程序配置為在自定義端口上運行&#xff0c;例如避免端口沖突、適配微服務架構…

linux嵌入式(進程與線程1)

Linux進程 進程介紹 1. 進程的基本概念 定義&#xff1a;進程是程序的一次執行過程&#xff0c;擁有獨立的地址空間、資源&#xff08;如內存、文件描述符&#xff09;和唯一的進程 ID&#xff08;PID&#xff09;。 組成&#xff1a; 代碼段&#xff1a;程序的指令。 數據…

智馭未來:NVIDIA自動駕駛安全白皮書與實驗室創新實踐深度解析

一、引言&#xff1a;自動駕駛安全的范式革新 在當今數字化浪潮的推動下&#xff0c;全球自動駕駛技術正大步邁入商業化的深水區。隨著越來越多的自動駕駛車輛走上道路&#xff0c;其安全性已成為整個行業乃至社會關注的核心命題。在這個關鍵的轉折點上&#xff0c;NVIDIA 憑借…