從0搭建Transformer

1. 位置編碼模塊:

import torch
import torch.nn as nn
import mathclass PositonalEncoding(nn.Module):def __init__ (self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)# [[1, 2, 3],# [4, 5, 6],# [7, 8, 9]]pe = torch.zeros(max_len, d_model)# [[0],# [1],# [2]]position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)# 位置編碼固定,不更新參數# 保存模型時會保存緩沖區,在引入模型時緩沖區也被引入self.register_buffer('pe', pe)def forward(self, x):# 不計算梯度x = x + self.pe[:, :x.size(1)].requires_grad_(False)

2. 多頭注意力模塊

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0self.d_k = d_model // num_headsself.num_heads = num_headsself.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, d_model)self.W_v = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)self.W_o = nn.Linear(d_model, d_model)def forward(self, query, key, value, mask=None):batch_size = query.size(0)Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)attn_weights = torch.softmax(scores, dim=-1)context = torch.matmul(attn_weights, V)context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_k * self.num_heads)return self.W_o(context)

3. 編碼器層

class EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout = 0.1):super().__init__()self.atten = MultiHeadAttention(d_model, num_heads)self.feed_forward = nn.Sequential(nn.Linear(d_model, d_ff),nn.ReLU(),nn.Linear(d_ff, d_model))self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):attn_output = self.attn(x, x, x, mask)x = self.norm1(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm2(x + self.dropout(ff_output))return x

4. 解碼器層

class DecoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout=0.1):super(DecoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.feed_forward = nn.Sequential(nn.Linear(d_model, d_ff),nn.ReLU(),nn.Linear(d_ff, d_model))self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, src_mask, tgt_mask):attn_output = self.self_attn(x, x, x, tgt_mask)x = self.norm1(x + self.dropout(attn_output))attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)x = self.norm2(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm3(x + self.dropout(ff_output))return x

5. 模型整合

class Transformer(nn.module):def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout=0.1):super(Transformer, self).__init__()self.encoder_embed = nn.Embedding(src_vocab_size, d_model)self.decoder_embed = nn.Embedding(tgt_vocab_size, d_model)self.pos_encoder = PositionalEncoding(d_model, dropout)self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.fc_out = nn.Linear(d_model, tgt_vocab_size)def encode(self, src, src_mask):src_embeded = self.encoder_embed(src)src = self.pos_encoder(src_embeded)for layer in self.encoder_layers:src = layer(src, src_mask)return srcdef decode(self, tgt, enc_output, src_mask, tgt_mask):tgt_embeded = self.decoder_embed(tgt)tgt = self.pos_encoder(tgt_embeded)for layer in self.decoder_layers:tgt = layer(tgt, enc_output, src_mask, tgt_mask)return tgtdef forward(self, src, tgt, src_mask, tgt_mask):enc_output = self.encode(src, src_mask)dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)logits = self.fc_out(dec_output)return logits

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/79219.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/79219.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/79219.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Bootstrap V4系列】學習入門教程之 表格(Tables)和畫像(Figure)

Bootstrap V4系列 學習入門教程之 表格(Tables)和畫像(Figure) 表格(Tables)一、Examples二、Table head options 表格頭選項三、Striped rows 條紋行四、Bordered table 帶邊框的表格五、Borderless table…

在C# WebApi 中使用 Nacos02: 配置管理、服務管理實戰

一、配置管理 1.添加一個新的命名空間 這里我都填寫為publicdemo 2.C#代碼配置啟動 appsetting.json加上: (nacos默認是8848端口) "NacosConfig": {"ServerAddresses": [ "http://localhost:8848" ], // Nacos 服務器地址"Na…

如何搭建spark yarn 模式的集群集群。

下載 App 如何搭建spark yarn 模式的集群集群。 搭建Spark on YARN集群的詳細步驟 Spark on YARN模式允許Spark作業在Hadoop YARN資源管理器上運行,利用YARN進行資源調度。以下是搭建步驟: 一、前提條件 已安裝并配置好的Hadoop集群(包括HDF…

C++--入門基礎

C入門基礎 1. C的第一個程序 C繼承C語言許多大多數的語法,所以以C語言實現的hello world也可以運行,C中需要把文件定義為.cpp,vs編譯器看是.cpp就會調用C編譯器編譯,linux下要用g編譯,不再是gcc。 // test.cpp #inc…

從實列中學習linux shell9 如何確認 服務器反應遲鈍是因為cpu還是 硬盤io 到底是那個程序引起的。cpu負載多高算高

在 Linux 系統中,Load Average(平均負載) 是衡量系統整體壓力的關鍵指標,但它本身沒有絕對的“高/低”閾值,需要結合 CPU 核心數 和 其他性能指標 綜合分析。以下是具體判斷方法: 一、Load Average 的基本含義 定義:Load Average 表示 單位時間內處于可運行狀態(R)和不…

聊一聊接口測試更側重于哪方面的驗證

目錄 一、功能性驗證 輸入與輸出正確性 參數校驗 業務邏輯覆蓋 二、數據一致性驗證 數據格式規范 數據完整性 數據類型與范圍 三、異常場景驗證 容錯能力測試 邊界條件覆蓋 錯誤碼與信息清晰度 四、安全與權限驗證 身份認證 數據安全 防攻擊能力 五、性能與可…

Fiddler抓取APP端,HTTPS報錯全解析及解決方案(一篇解決常見問題)

環境:雷電模擬器Android9系統 ? 你所遇到的fiddler中抓取HTTPS的問題可以分為三類:一類是你自己證書安裝上邏輯錯誤,另一種是APP中使用了“證書固定”的手段。三類fiddler中生成證書時的參數過程。 1.Fiddler證書安裝上的邏輯錯誤 更新Opt…

OpenGL-ES 學習(15) ----紋理

目錄 紋理簡介紋理映射紋理映射流程示例代碼:紋理的環繞和過濾方式紋理的過濾方式 紋理簡介 現實生活中,紋理(Texture) 類似于游戲中皮膚的概念,最通常的作用是裝飾 3D 物體,它像貼紙一樣貼在物體的表面,豐富物體的表…

OpenCV計算機視覺實戰(2)——環境搭建與OpenCV簡介

OpenCV計算機視覺實戰(2)——環境搭建與OpenCV簡介 0. 前言1. OpenCV 安裝與配置1.1 安裝 Python-OpenCV1.2 配置開發環境 2. OpenCV 基礎2.1 圖像讀取與顯示2.2 圖像保存 3. 攝像頭實時捕獲小結系列鏈接 0. 前言 OpenCV (Open Source Computer Vision …

ubuntu22.04安裝顯卡驅動與cuda+cuDNN

背景: 緊接前文:Proxmox VE 8.4 顯卡直通完整指南:NVIDIA 2080 Ti 實戰。在R740服務器完成了proxmox的安裝,并且安裝了一張2080ti 魔改22g顯存的的顯卡。配置完了proxmox顯卡直通,并將顯卡掛載到了vm 301(…

A2A Python 教程 - 綜合指南

目錄 ? 介紹? 設置環境? 創建項目? 代理技能? 代理卡片? A2A服務器? 與A2A服務器交互? 添加代理功能? 使用本地Ollama模型? 后續步驟 介紹 在本教程中,您將使用Python構建一個簡單的echo A2A服務器。這個基礎實現將向您展示A2A提供的所有功能。完成本教…

MySQL基礎關鍵_005_DQL(四)

目 錄 一、分組函數 1.說明 2.max/min 3.sum/avg/count 二、分組查詢 1.說明 2.實例 (1)查詢崗位和平均薪資 (2)查詢每個部門編號的不同崗位的最低薪資 3.having (1)說明 (2&#xff…

GAMES202-高質量實時渲染(Assignment 2)

目錄 作業介紹環境光貼圖預計算傳輸項的預計算Diffuse unshadowedDiffuse shadowedDiffuse Inter-reflection(bonus) 實時球諧光照計算 GitHub主頁:https://github.com/sdpyy1 作業實現:https://github.com/sdpyy1/CppLearn/tree/main/games202 作業介紹 物體在不同…

2025年- H21-Lc129-160. 相交鏈表(鏈表)---java版

1.題目描述 2.思路 當pa!pb的時候,執行pa不為空,遍歷pa鏈表。執行pb不為空,遍歷pb鏈表。 3.代碼實現 // 單鏈表節點定義 class ListNode {int val;ListNode next;ListNode(int x){valx;nextnull;}}public class H160 {// 主方法…

win10系統安卓開發環境搭建

一 安裝jdk 下載jdk17 ,下載路徑:https://download.oracle.com/java/17/archive/jdk-17.0.12_windows-x64_bin.exe 下載完畢后,按照提示一步步完成,然后接著創建環境變量, 在cmd控制臺輸入java -version 驗證: 有上面的輸出代表jdk安裝并配置成功。 二 安裝Android stu…

【算法基礎】選擇排序算法 - JAVA

一、算法基礎 1.1 什么是選擇排序 選擇排序是一種簡單直觀的排序算法,它的工作原理是:首先在未排序序列中找到最小(或最大)元素,存放到排序序列的起始位置,然后再從剩余未排序元素中繼續尋找最小&#xf…

LabVIEW異步調用VI介紹

在 LabVIEW 編程環境里,借助結合異步 VI 調用,并使用 “Open VI Reference” 函數上的 “Enable simultaneous calls on reentrant VIs” 選項(0x40),達成了對多個 VI 調用執行效率的優化。以下將從多方面詳細介紹該 V…

Leetcode刷題 | Day50_圖論02_島嶼問題01_dfs兩種方法+bfs一種方法

一、學習任務 99. 島嶼數量_深搜dfs代碼隨想錄99. 島嶼數量_廣搜bfs100. 島嶼的最大面積101. 孤島的總面積 第一類DFS(主函數中處理第一個節點,DFS處理相連節點): 主函數中先將起始節點標記為已訪問DFS函數中不處理起始節點&…

深入理解網絡安全中的加密技術

1 引言 在當今數字化的世界中,網絡安全已經成為個人隱私保護、企業數據安全乃至國家安全的重要組成部分。隨著網絡攻擊的復雜性和頻率不斷增加,保護敏感信息不被未授權訪問變得尤為關鍵。加密技術作為保障信息安全的核心手段,通過將信息轉換為…