第TR3周:Pytorch復現Transformer

  • ???🍨?本文為🔗365天深度學習訓練營中的學習記錄博客
  • ? ?🍖?原作者:K同學啊

Transformer通過自注意力機制,改變了序列建模的方式,成為AI領域的基礎架構

編碼器:理解輸入,提取上下文特征。

解碼器:基于編碼特征,按順序生成輸出。


1.多頭注意力機制

import math
import torch
import torch.nn as nndevice=torch.device('cuda' if torch.cuda.is_available() else 'cpu')class MultiHeadAttention(nn.Module):# n_heads:多頭注意力的數量# hid_dim:每個詞輸出的向量維度def __init__(self,hid_dim,n_heads):super(MultiHeadAttention,self).__init__()self.hid_dim=hid_dimself.n_heads=n_heads#強制hid_dim必須整除 hassert hid_dim % n_heads == 0#定義W_q矩陣ceself.w_q=nn.Linear(hid_dim,hid_dim)#定義W_k矩陣self.w_k=nn.Linear(hid_dim,hid_dim)#定義W_v矩陣self.w_v=nn.Linear(hid_dim,hid_dim)self.fc =nn.Linear(hid_dim,hid_dim)#縮放self.scale=torch.sqrt(torch.FloatTensor([hid_dim//n_heads]))def forward(self,query,key,value,mask=None):#Q,K,V的在句子這長度這一個維度的數值可以不一樣,可以一樣#K:[64,10,300],假設batch_size為64,有10個詞,每個詞的Query向量是300維bsz=query.shape[0]Q  =self.w_q(query)K  =self.w_k(key)V  =self.w_v(value)#這里把K Q V 矩陣拆分為多組注意力#最后一維就是是用self.hid_dim // self.n_heads 來得到的,表示每組注意力的向量長度,每個head的向量長度是:300/6=50#64表示batch size,6表示有6組注意力,10表示有10詞,50表示每組注意力的詞的向量長度#K: [64,10,300] 拆分多組注意力 -> [64,10,6,50] 轉置得到 -> [64,6,10,50]#轉置是為了把注意力的數量6放在前面,把10和50放在后面,方便下面計算Q=Q.view(bsz,-1,self.n_heads,self.hid_dim//self.n_heads).permute(0,2,1,3)K=K.view(bsz,-1,self.n_heads,self.hid_dim//self.n_heads).permute(0,2,1,3)V=V.view(bsz,-1,self.n_heads,self.hid_dim//self.n_heads).permute(0,2,1,3)#Q乘以K的轉置,除以scale#[64,6,12,50]*[64,6,50,10]=[64,6,12,10]#attention:[64,6,12,10]attention=torch.matmul(Q,K.permute(0,1,3,2)) / self.scale#如果mask不為空,那么就把mask為0的位置的attention分數設置為-1e10,這里用‘0’來指示哪些位置的詞向量不能被attention到,比如padding位置if mask is not None:attention=attention.masked_fill(mask==0,-1e10)#第二步:計算上一步結果的softmax,再經過dropout,得到attention#注意,這里是對最后一維做softmax,也就是在輸入序列的維度做softmax#attention: [64,6,12,10]attention=torch.softmax(attention,dim=-1)#第三步,attention結果與V相乘,得到多頭注意力的結果#[64,6,12,10] * [64,6,10,50] =[64,6,12,50]# x: [64,6,12,50]x=torch.matmul(attention,V)#因為query有12個詞,所以把12放在前面,把50和6放在后面,方便下面拼接多組的結果#x: [64,6,12,50] 轉置 -> [64,12,6,50]x=x.permute(0,2,1,3).contiguous()#這里的矩陣轉換就是:把多頭注意力的結果拼接起來#最后結果就是[64,12,300]# x:[64,12,6,50] -> [64,12,300]x=x.view(bsz,-1,self.n_heads*(self.hid_dim//self.n_heads))x=self.fc(x)return x

2.前饋傳播

class Feedforward(nn.Module):def __init__(self,d_model,d_ff,dropout=0.1):super(Feedforward,self).__init__()#兩層線性映射和激活函數self.linear1=nn.Linear(d_model,d_ff)self.dropout=nn.Dropout(dropout)self.linear2=nn.Linear(d_ff,d_model)def forward(self,x):x=torch.nn.functional.relu(self.linear1(x))x=self.dropout(x)x=self.linear2(x)return x

3.位置編碼

class PositionalEncoding(nn.Module):"實現位置編碼"def __init__(self, d_model, dropout=0.1, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)# 初始化Shape為(max_len, d_model)的PE (positional encoding)pe = torch.zeros(max_len, d_model).to(device)# 初始化一個tensor [[0, 1, 2, 3, ...]]position = torch.arange(0, max_len).unsqueeze(1)# 這里就是sin和cos括號中的內容,通過e和ln進行了變換div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term) # 計算PE(pos, 2i)pe[:, 1::2] = torch.cos(position * div_term) # 計算PE(pos, 2i+1)pe = pe.unsqueeze(0) # 為了方便計算,在最外面在unsqueeze出一個batch# 如果一個參數不參與梯度下降,但又希望保存model的時候將其保存下來# 這個時候就可以用register_bufferself.register_buffer("pe", pe)def forward(self, x):"""x 為embedding后的inputs,例如(1,7, 128),batch size為1,7個單詞,單詞維度為128"""# 將x和positional encoding相加。x = x + self.pe[:, :x.size(1)].requires_grad_(False)return self.dropout(x)

4.編碼層

class EncoderLayer(nn.Module):def __init__(self,d_model,n_heads,d_ff,dropout=0.1):super(EncoderLayer,self).__init__()#編碼器層包含自注意機制和前饋神經網絡self.self_attn=MultiHeadAttention(d_model,n_heads)self.feedforward=Feedforward(d_model,d_ff,dropout)self.norm1=nn.LayerNorm(d_model)self.norm2=nn.LayerNorm(d_model)self.dropout=nn.Dropout(dropout)def forward(self,x,mask):#自注意力機制atten_output=self.self_attn(x,x,x,mask)x=x+self.dropout(atten_output)x=self.norm1(x)#前饋神經網絡ff_output=self.feedforward(x)x=x+self.dropout(ff_output)x=self.norm2(x)return x

5.解碼層

class DecoderLayer(nn.Module):def __init__(self, d_model, n_heads, d_ff, dropout=0.1):super(DecoderLayer, self).__init__()# 解碼器層包含自注意力機制、編碼器-解碼器注意力機制和前饋神經網絡self.self_attn   = MultiHeadAttention(d_model, n_heads)self.enc_attn    = MultiHeadAttention(d_model, n_heads)self.feedforward = Feedforward(d_model, d_ff, dropout)self.norm1   = nn.LayerNorm(d_model)self.norm2   = nn.LayerNorm(d_model)self.norm3   = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, self_mask, context_mask):# 自注意力機制attn_output = self.self_attn(x, x, x, self_mask)x           = x + self.dropout(attn_output)x           = self.norm1(x)# 編碼器-解碼器注意力機制attn_output = self.enc_attn(x, enc_output, enc_output, context_mask)x           = x + self.dropout(attn_output)x           = self.norm2(x)# 前饋神經網絡ff_output = self.feedforward(x)x = x + self.dropout(ff_output)x = self.norm3(x)return x

6.Transformer模型構建

class Transformer(nn.Module):def __init__(self, vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout=0.1):super(Transformer, self).__init__()# Transformer 模型包含詞嵌入、位置編碼、編碼器和解碼器self.embedding           = nn.Embedding(vocab_size, d_model)self.positional_encoding = PositionalEncoding(d_model)self.encoder_layers      = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_encoder_layers)])self.decoder_layers      = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_decoder_layers)])self.fc_out              = nn.Linear(d_model, vocab_size)self.dropout             = nn.Dropout(dropout)def forward(self, src, trg, src_mask, trg_mask):# 詞嵌入和位置編碼src = self.embedding(src)src = self.positional_encoding(src)trg = self.embedding(trg)trg = self.positional_encoding(trg)# 編碼器for layer in self.encoder_layers:src = layer(src, src_mask)# 解碼器for layer in self.decoder_layers:trg = layer(trg, src, trg_mask, src_mask)# 輸出層output = self.fc_out(trg)return output

vocab_size = 10000
d_model    = 128
n_heads    = 8
n_encoder_layers = 6
n_decoder_layers = 6
d_ff             = 2048
dropout          = 0.1device = torch.device('cpu')transformer_model = Transformer(vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout)# 定義輸入
src = torch.randint(0, vocab_size, (32, 10))  # 源語言句子
trg = torch.randint(0, vocab_size, (32, 20))  # 目標語言句子
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # 掩碼,用于屏蔽填充的位置
trg_mask = (trg != 0).unsqueeze(1).unsqueeze(2)  # 掩碼,用于屏蔽填充的位置# 模型前向傳播
output = transformer_model(src, trg, src_mask, trg_mask)
print(output.shape)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/71577.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/71577.shtml
英文地址,請注明出處:http://en.pswp.cn/web/71577.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

FreeRTOS 任務間通信機制:隊列、信號量、事件標志組詳解與實驗

1. FreeRTOS 消息隊列 1.1 簡介 ? 隊列是 任務間通信的主要形式,可用于在任務之間以及中斷與任務之間傳遞消息。隊列在 FreeRTOS 中具有以下關鍵特點: 隊列默認采用 先進先出 FIFO 方式,也可以使用 xQueueSendToFront()實現 LIFO。FreeRT…

【虛擬化】Docker Desktop 架構簡介

在閱讀前您需要了解 docker 架構:Docker architecture WSL 技術:什么是 WSL 2 1.Hyper-V backend 我們知道,Docker Desktop 最開始的架構的后端是采用的 Hyper-V。 Docker daemon (dockerd) 運行在一個 Linux distro (LinuxKit build) 中&…

Unity光照之Halo組件

簡介 Halo 組件 是一種用于在游戲中創建光暈效果的工具,主要用于模擬光源周圍的發光區域(如太陽、燈泡等)或物體表面的光線反射擴散效果。 核心功能 1.光暈生成 Halo 組件會在光源或物體的周圍生成一個圓形光暈,模擬光線在空氣…

Flink深入淺出之01:應用場景、基本架構、部署模式

Flink 1?? 一 、知識要點 📖 1. Flink簡介 Apache Flink — Stateful Computations over Data StreamsApache Flink 是一個分布式大數據處理引擎,可對有界數據流和無界數據流進行有狀態的計算。Flink 能在所有常見集群環境中運行,并能以…

2025年【高壓電工】報名考試及高壓電工考試總結

隨著電力行業的快速發展,高壓電工成為確保電力系統安全穩定運行的重要一環。為了提高高壓電工的專業技能和安全意識,“安全生產模擬考試一點通”平臺特別整理了2025年高壓電工報名考試的相關信息及考試總結,并提供了一套完整的題庫&#xff0…

網絡HTTP

HTTP Network Request Library A Retrofit-based HTTP network request encapsulation library that provides simple and easy-to-use API interfaces with complete network request functionality. 基于Retrofit的HTTP網絡請求封裝庫,提供簡單易用的API接口和完…

os-copilot安裝和使用體驗測評

簡介: OS Copilot是阿里云基于大模型構建的Linux系統智能助手,支持自然語言問答、命令執行和系統運維調優。本文介紹其產品優勢、功能及使用方法,并分享個人開發者在云服務器資源管理中的實際應用體驗。通過-t/-f/管道功能,OS Cop…

Python Flask框架學習匯編

1、入門級: 《Python Flask Web 框架入門》 這篇博文條理清晰,由簡入繁,案例豐富,分十五節詳細講解了Flask框架,強烈推薦! 《python的簡單web框架flask【附例子】》 講解的特別清楚,每一步都…

【HarmonyOS Next之旅】DevEco Studio使用指南(一)

目錄 1 -> 工具簡介 1.1 -> 概述 1.2 -> HarmonyOS應用/服務開發流程 1.2.1 -> 開發準備 1.2.2 -> 開發應用/服務 1.2.3 -> 運行、調試和測試應用/服務 1.2.4 -> 發布應用/服務 2 -> 工程介紹 2.1 -> APP包結構 2.2 -> 切換工程視圖 …

Manus開源平替-開源通用智能體

原文鏈接:https://i68.ltd/notes/posts/250306-opensource-agi-agent/ OWL-比Manus還強的全能開源Agent OWL: Optimized Workforce Learning for General Multi-Agent Assistance in Real-World Task Automation,現實世界中執行自動化任務的通用多代理輔助優化學習…

【3.2-3.8學習周報】

提示:文章寫完后,目錄可以自動生成,如何生成可參考右邊的幫助文檔 文章目錄 摘要Abstract一、方法介紹1.任務適應性持續預訓練(TACP)2.領域自適應連續預訓練(DACP)3.ETS-DACP和ETA-DACP 二、實驗…

【Linux】用戶和組

思考 使用useradd在Linux下面創建一個用戶,默認情況下,會自動創建一個同名組,并且加入其中,那么是先創建用戶呢?還是先創建組呢? 很簡單,我們實踐一下不就知道了,如下所示&#xff…

新編大學應用英語綜合教程2 U校園全套參考答案

全套答案獲取: 鏈接:https://pan.quark.cn/s/389618f53143

SAP 顧問的五年職業規劃

SAP 顧問的職業發展受到技術進步、企業需求變化和全球經濟環境的影響,因此制定長遠規劃充滿挑戰。面對 SAP 產品路線圖的不確定性,如向 S/4HANA 和 Business Technology Platform (BTP) 的轉變,顧問必須具備靈活性,以保持競爭力和…

圖像生成-ICCV2019-SinGAN: Learning a Generative Model from a Single Natural Image

圖像生成-ICCV2019-SinGAN: Learning a Generative Model from a Single Natural Image 文章目錄 圖像生成-ICCV2019-SinGAN: Learning a Generative Model from a Single Natural Image主要創新點模型架構圖生成器生成器源碼 判別器判別器源碼 損失函數需要源碼講解的私信我 S…

Networking Based ISAC Hardware Testbed and Performance Evaluation

文章目錄 Applications and Challenges of Networked SensingCooperation Mechanism in Networked SensingChallenges and Key Enabling Technologies 5G NR Frame Structure Based ISAC ApproachSignals Available for Radio SensingMulti-Dimensiona Resource Optimization S…

2025年主流原型工具測評:墨刀、Axure、Figma、Sketch

2025年主流原型工具測評:墨刀、Axure、Figma、Sketch 要說2025年國內產品經理使用的主流原型設計工具,當然是墨刀、Axure、Figma和Sketch了,但是很多剛入行的產品經理不了解自己適合哪些工具,本文將從核心優勢、局限短板、協作能…

我代表中國受邀在亞馬遜云科技全球云計算大會re:Invent中技術演講

大家好我是小李哥,本名叫李少奕,目前在一家金融行業公司擔任首席云計算工程師。去年5月很榮幸在全球千萬名開發者中被選為了全球亞馬遜云科技認證技術專家(AWS Hero),是近10年來大陸地區僅有的第9名大陸專家。同時作為…

LeetCode 解題思路 12(Hot 100)

解題思路: 定義三個指針: prev(前驅節點)、current(當前節點)、nextNode(臨時保存下一個節點)遍歷鏈表: 每次將 current.next 指向 prev,移動指針直到 curre…

Ubuntu搭建最簡單WEB服務器

安裝apache2 sudo apt install apache2 檢查狀態 $ sudo systemctl status apache2 ● apache2.service - The Apache HTTP ServerLoaded: loaded (/lib/systemd/system/apache2.service; enabled; vendor prese>Active: active (running) since Thu 2025-03-06 09:51:10…